diff --git a/python/packages/autogen-core/src/autogen_core/_queue.py b/python/packages/autogen-core/src/autogen_core/_queue.py new file mode 100644 index 000000000000..699921a37f5d --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_queue.py @@ -0,0 +1,264 @@ +# Copy of Asyncio queue: https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py +# So that shutdown can be used in <3.13 +# Modified to work outside of the asyncio package + +import asyncio +import collections +import threading +from typing import Generic, TypeVar + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> asyncio.AbstractEventLoop: + loop = asyncio.get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class QueueShutDown(Exception): + """Raised when putting on to or getting from a shut-down Queue.""" + + pass + + +T = TypeVar("T") + + +class Queue(_LoopBoundMixin, Generic[T]): + def __init__(self, maxsize: int = 0): + self._maxsize = maxsize + self._getters = collections.deque[asyncio.Future[None]]() + self._putters = collections.deque[asyncio.Future[None]]() + self._unfinished_tasks = 0 + self._finished = asyncio.Event() + self._finished.set() + self._queue = collections.deque[T]() + self._is_shutdown = False + + # These three are overridable in subclasses. + + def _get(self) -> T: + return self._queue.popleft() + + def _put(self, item: T) -> None: + self._queue.append(item) + + # End of the overridable methods. + + def _wakeup_next(self, waiters: collections.deque[asyncio.Future[None]]) -> None: + # Wake up the next waiter (if any) that isn't cancelled. + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break + + def __repr__(self) -> str: + return f"<{type(self).__name__} at {id(self):#x} {self._format()}>" + + def __str__(self) -> str: + return f"<{type(self).__name__} {self._format()}>" + + def _format(self) -> str: + result = f"maxsize={self._maxsize!r}" + if getattr(self, "_queue", None): + result += f" _queue={list(self._queue)!r}" + if self._getters: + result += f" _getters[{len(self._getters)}]" + if self._putters: + result += f" _putters[{len(self._putters)}]" + if self._unfinished_tasks: + result += f" tasks={self._unfinished_tasks}" + if self._is_shutdown: + result += " shutdown" + return result + + def qsize(self) -> int: + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self) -> int: + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self) -> bool: + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self) -> bool: + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + async def put(self, item: T) -> None: + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + Raises QueueShutDown if the queue has been shut down. + """ + while self.full(): + if self._is_shutdown: + raise QueueShutDown + putter = self._get_loop().create_future() + self._putters.append(putter) + try: + await putter + except: + putter.cancel() # Just in case putter is not done yet. + try: + # Clean self._putters from canceled putters. + self._putters.remove(putter) + except ValueError: + # The putter could be removed from self._putters by a + # previous get_nowait call or a shutdown call. + pass + if not self.full() and not putter.cancelled(): + # We were woken up by get_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._putters) + raise + return self.put_nowait(item) + + def put_nowait(self, item: T) -> None: + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + + Raises QueueShutDown if the queue has been shut down. + """ + if self._is_shutdown: + raise QueueShutDown + if self.full(): + raise asyncio.QueueFull + self._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + self._wakeup_next(self._getters) + + async def get(self) -> T: + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + + Raises QueueShutDown if the queue has been shut down and is empty, or + if the queue has been shut down immediately. + """ + while self.empty(): + if self._is_shutdown and self.empty(): + raise QueueShutDown + getter = self._get_loop().create_future() + self._getters.append(getter) + try: + await getter + except: + getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. + self._getters.remove(getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call, or a shutdown call. + pass + if not self.empty() and not getter.cancelled(): + # We were woken up by put_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._getters) + raise + return self.get_nowait() + + def get_nowait(self) -> T: + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + + Raises QueueShutDown if the queue has been shut down and is empty, or + if the queue has been shut down immediately. + """ + if self.empty(): + if self._is_shutdown: + raise QueueShutDown + raise asyncio.QueueEmpty + item = self._get() + self._wakeup_next(self._putters) + return item + + def task_done(self) -> None: + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + shutdown(immediate=True) calls task_done() for each remaining item in + the queue. + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self) -> None: + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + await self._finished.wait() + + def shutdown(self, immediate: bool = False) -> None: + """Shut-down the queue, making queue gets and puts raise QueueShutDown. + + By default, gets will only raise once the queue is empty. Set + 'immediate' to True to make gets raise immediately instead. + + All blocked callers of put() and get() will be unblocked. If + 'immediate', a task is marked as done for each item remaining in + the queue, which may unblock callers of join(). + """ + self._is_shutdown = True + if immediate: + while not self.empty(): + self._get() + if self._unfinished_tasks > 0: + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + # All getters need to re-check queue-empty to raise ShutDown + while self._getters: + getter = self._getters.popleft() + if not getter.done(): + getter.set_result(None) + while self._putters: + putter = self._putters.popleft() + if not putter.done(): + putter.set_result(None) diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 5141ac9bccec..a00e531fc962 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -3,17 +3,24 @@ import asyncio import inspect import logging +import sys import threading import uuid import warnings -from asyncio import CancelledError, Future, Task +from asyncio import CancelledError, Future, Queue, Task from collections.abc import Sequence from dataclasses import dataclass -from enum import Enum from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast from opentelemetry.trace import TracerProvider +if sys.version_info >= (3, 13): + from asyncio import Queue, QueueShutDown +else: + from ._queue import Queue, QueueShutDown # type: ignore + +from typing_extensions import deprecated + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -100,48 +107,36 @@ def decrement(self) -> None: class RunContext: - class RunState(Enum): - RUNNING = 0 - CANCELLED = 1 - UNTIL_IDLE = 2 - def __init__(self, runtime: SingleThreadedAgentRuntime) -> None: self._runtime = runtime - self._run_state = RunContext.RunState.RUNNING - self._end_condition: Callable[[], bool] = self._stop_when_cancelled self._run_task = asyncio.create_task(self._run()) - self._lock = asyncio.Lock() + self._stopped = asyncio.Event() async def _run(self) -> None: while True: - async with self._lock: - if self._end_condition(): - return + if self._stopped.is_set(): + return - await self._runtime.process_next() + await self._runtime._process_next() # type: ignore async def stop(self) -> None: - async with self._lock: - self._run_state = RunContext.RunState.CANCELLED - self._end_condition = self._stop_when_cancelled + self._stopped.set() + self._runtime._message_queue.shutdown(immediate=True) # type: ignore await self._run_task async def stop_when_idle(self) -> None: - async with self._lock: - self._run_state = RunContext.RunState.UNTIL_IDLE - self._end_condition = self._stop_when_idle + await self._runtime._message_queue.join() # type: ignore + self._stopped.set() + self._runtime._message_queue.shutdown(immediate=True) # type: ignore await self._run_task - async def stop_when(self, condition: Callable[[], bool]) -> None: - async with self._lock: - self._end_condition = condition - await self._run_task - - def _stop_when_cancelled(self) -> bool: - return self._run_state == RunContext.RunState.CANCELLED + async def stop_when(self, condition: Callable[[], bool], check_period: float = 1.0) -> None: + async def check_condition() -> None: + while not condition(): + await asyncio.sleep(check_period) + await self.stop() - def _stop_when_idle(self) -> bool: - return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle + await asyncio.create_task(check_condition()) def _warn_if_none(value: Any, handler_name: str) -> None: @@ -169,28 +164,23 @@ def __init__( tracer_provider: TracerProvider | None = None, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) - self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] + self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue() # (namespace, type) -> List[AgentId] self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] ] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} self._intervention_handlers = intervention_handlers - self._outstanding_tasks = Counter() self._background_tasks: Set[Task[Any]] = set() self._subscription_manager = SubscriptionManager() self._run_context: RunContext | None = None self._serialization_registry = SerializationRegistry() @property - def unprocessed_messages( + def unprocessed_messages_count( self, - ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: - return self._message_queue - - @property - def outstanding_tasks(self) -> int: - return self._outstanding_tasks.get() + ) -> int: + return self._message_queue.qsize() @property def _known_agent_names(self) -> Set[str]: @@ -231,7 +221,7 @@ async def send_message( content = message.__dict__ if hasattr(message, "__dict__") else message logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - self._message_queue.append( + await self._message_queue.put( SendMessageEnvelope( message=message, recipient=recipient, @@ -279,7 +269,7 @@ async def publish_message( # ) # ) - self._message_queue.append( + await self._message_queue.put( PublishMessageEnvelope( message=message, cancellation_token=cancellation_token, @@ -340,14 +330,14 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: except CancelledError as e: if not message_envelope.future.cancelled(): message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() + self._message_queue.task_done() return except BaseException as e: message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() + self._message_queue.task_done() return - self._message_queue.append( + await self._message_queue.put( ResponseMessageEnvelope( message=response, future=message_envelope.future, @@ -356,7 +346,7 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: metadata=get_telemetry_envelope_metadata(), ) ) - self._outstanding_tasks.decrement() + self._message_queue.task_done() async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata): @@ -411,7 +401,7 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: return logger.error("Error processing publish message", exc_info=True) finally: - self._outstanding_tasks.decrement() + self._message_queue.task_done() # TODO if responses are given for a publish async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: @@ -433,18 +423,21 @@ async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> # delivery_stage=DeliveryStage.DELIVER, # ) # ) - self._outstanding_tasks.decrement() + self._message_queue.task_done() if not message_envelope.future.cancelled(): message_envelope.future.set_result(message_envelope.message) + @deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.") async def process_next(self) -> None: + await self._process_next() + + async def _process_next(self) -> None: """Process the next message in the queue.""" - if len(self._message_queue) == 0: - # Yield control to the event loop to allow other tasks to run - await asyncio.sleep(0) + try: + message_envelope = await self._message_queue.get() + except QueueShutDown: return - message_envelope = self._message_queue.pop(0) match message_envelope: case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): @@ -464,7 +457,6 @@ async def process_next(self) -> None: return message_envelope.message = temp_message - self._outstanding_tasks.increment() task = asyncio.create_task(self._process_send(message_envelope)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -489,7 +481,6 @@ async def process_next(self) -> None: return message_envelope.message = temp_message - self._outstanding_tasks.increment() task = asyncio.create_task(self._process_publish(message_envelope)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -507,7 +498,6 @@ async def process_next(self) -> None: future.set_exception(MessageDroppedException()) return message_envelope.message = temp_message - self._outstanding_tasks.increment() task = asyncio.create_task(self._process_response(message_envelope)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -515,37 +505,59 @@ async def process_next(self) -> None: # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) - @property - def idle(self) -> bool: - return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0 - def start(self) -> None: - """Start the runtime message processing loop.""" + """Start the runtime message processing loop. This runs in a background task. + + Example: + + .. code-block:: python + + from autogen_core import SingleThreadedAgentRuntime + + runtime = SingleThreadedAgentRuntime() + runtime.start() + + """ if self._run_context is not None: raise RuntimeError("Runtime is already started") self._run_context = RunContext(self) async def stop(self) -> None: - """Stop the runtime message processing loop.""" + """Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded.""" if self._run_context is None: raise RuntimeError("Runtime is not started") await self._run_context.stop() self._run_context = None + self._message_queue = Queue() async def stop_when_idle(self) -> None: """Stop the runtime message processing loop when there is - no outstanding message being processed or queued.""" + no outstanding message being processed or queued. This is the most common way to stop the runtime.""" if self._run_context is None: raise RuntimeError("Runtime is not started") await self._run_context.stop_when_idle() self._run_context = None + self._message_queue = Queue() async def stop_when(self, condition: Callable[[], bool]) -> None: - """Stop the runtime message processing loop when the condition is met.""" + """Stop the runtime message processing loop when the condition is met. + + .. caution:: + + This method is not recommended to be used, and is here for legacy + reasons. It will spawn a busy loop to continually check the + condition. It is much more efficient to call `stop_when_idle` or + `stop` instead. If you need to stop the runtime based on a + condition, consider using a background task and asyncio.Event to + signal when the condition is met and the background task should call + stop. + + """ if self._run_context is None: raise RuntimeError("Runtime is not started") await self._run_context.stop_when(condition) self._run_context = None + self._message_queue = Queue() async def agent_metadata(self, agent: AgentId) -> AgentMetadata: return (await self._get_agent(agent)).metadata diff --git a/python/packages/autogen-core/tests/test_cancellation.py b/python/packages/autogen-core/tests/test_cancellation.py index 34a5d7f962c4..9da513f934fe 100644 --- a/python/packages/autogen-core/tests/test_cancellation.py +++ b/python/packages/autogen-core/tests/test_cancellation.py @@ -71,10 +71,10 @@ async def test_cancellation_with_token() -> None: response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token)) assert not response.done() - while len(runtime.unprocessed_messages) == 0: + while runtime.unprocessed_messages_count == 0: await asyncio.sleep(0.01) - await runtime.process_next() + await runtime._process_next() # type: ignore token.cancel() @@ -104,10 +104,10 @@ async def test_nested_cancellation_only_outer_called() -> None: response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token)) assert not response.done() - while len(runtime.unprocessed_messages) == 0: + while runtime.unprocessed_messages_count == 0: await asyncio.sleep(0.01) - await runtime.process_next() + await runtime._process_next() # type: ignore token.cancel() with pytest.raises(asyncio.CancelledError): @@ -140,12 +140,12 @@ async def test_nested_cancellation_inner_called() -> None: response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token)) assert not response.done() - while len(runtime.unprocessed_messages) == 0: + while runtime.unprocessed_messages_count == 0: await asyncio.sleep(0.01) - await runtime.process_next() + await runtime._process_next() # type: ignore # allow the inner agent to process - await runtime.process_next() + await runtime._process_next() # type: ignore token.cancel() with pytest.raises(asyncio.CancelledError):