Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dapr_agents/agents/durable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
)
from dapr_agents.types.workflow import DaprWorkflowStatus
from dapr_agents.tool.utils.serialization import serialize_tool_result
from dapr_agents.workflow.decorators.routers import message_router
from dapr_agents.workflow.runners.agent import workflow_entry
from dapr_agents.workflow.decorators import message_router, workflow_entry
from dapr_agents.workflow.utils.grpc import apply_grpc_options
from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent

Expand Down
2 changes: 1 addition & 1 deletion dapr_agents/agents/orchestrators/llm/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TriggerAction,
)
from dapr_agents.workflow.decorators.routers import message_router
from dapr_agents.workflow.runners.agent import workflow_entry
from dapr_agents.workflow.decorators import workflow_entry
from dapr_agents.agents.orchestrators.llm.prompts import (
NEXT_STEP_PROMPT,
PROGRESS_CHECK_PROMPT,
Expand Down
2 changes: 1 addition & 1 deletion dapr_agents/agents/orchestrators/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TriggerAction,
)
from dapr_agents.workflow.decorators.routers import message_router
from dapr_agents.workflow.runners.agent import workflow_entry
from dapr_agents.workflow.decorators import workflow_entry
from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion dapr_agents/agents/orchestrators/roundrobin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TriggerAction,
)
from dapr_agents.workflow.decorators.routers import message_router
from dapr_agents.workflow.runners.agent import workflow_entry
from dapr_agents.workflow.decorators import workflow_entry
from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent

logger = logging.getLogger(__name__)
Expand Down
9 changes: 8 additions & 1 deletion dapr_agents/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from .decorators import message_router, http_router, llm_activity, agent_activity
from .decorators import (
message_router,
http_router,
llm_activity,
agent_activity,
workflow_entry,
)

__all__ = [
"message_router",
"http_router",
"llm_activity",
"agent_activity",
"workflow_entry",
]
3 changes: 2 additions & 1 deletion dapr_agents/workflow/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .routers import message_router, http_router
from .activities import llm_activity, agent_activity
from .activities import llm_activity, agent_activity, workflow_entry

__all__ = [
"message_router",
"http_router",
"llm_activity",
"agent_activity",
"workflow_entry",
]
25 changes: 24 additions & 1 deletion dapr_agents/workflow/decorators/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import inspect
import logging
from typing import Any, Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional, TypeVar

from dapr.ext.workflow import WorkflowActivityContext # type: ignore

Expand All @@ -23,6 +23,29 @@

logger = logging.getLogger(__name__)

R = TypeVar("R")


def workflow_entry(func: Callable[..., R]) -> Callable[..., R]:
"""
Mark a method/function as the workflow entrypoint for an Agent.

This decorator does not wrap the function; it simply annotates the callable
with `_is_workflow_entry = True` so AgentRunner can discover it on the agent
instance via reflection.

Usage:
class MyAgent:
@workflow_entry
def my_workflow(self, ctx: DaprWorkflowContext, wf_input: dict) -> str:
...

Returns:
The same callable (unmodified), with an identifying attribute.
"""
setattr(func, "_is_workflow_entry", True) # type: ignore[attr-defined]
return func


def llm_activity(
*,
Expand Down
126 changes: 93 additions & 33 deletions dapr_agents/workflow/runners/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import asyncio
import concurrent.futures
import logging
import threading
from typing import Any, Callable, Dict, Literal, Optional, TypeVar, Union
from threading import Lock, Thread
from typing import Any, Callable, Dict, Literal, Optional, TypeVar, Union, List

from fastapi import Body, FastAPI, HTTPException

from dapr_agents.agents.components import AgentComponents
from dapr_agents.agents.durable import DurableAgent
from dapr_agents.types.workflow import PubSubRouteSpec
from dapr_agents.workflow.runners.base import WorkflowRunner
from dapr_agents.workflow.utils.core import get_decorated_methods
Expand All @@ -21,27 +23,6 @@
R = TypeVar("R")


def workflow_entry(func: Callable[..., R]) -> Callable[..., R]:
"""
Mark a method/function as the workflow entrypoint for an Agent.

This decorator does not wrap the function; it simply annotates the callable
with `_is_workflow_entry = True` so AgentRunner can discover it on the agent
instance via reflection.

Usage:
class MyAgent:
@workflow_entry
def my_workflow(self, ctx: DaprWorkflowContext, wf_input: dict) -> str:
...

Returns:
The same callable (unmodified), with an identifying attribute.
"""
setattr(func, "_is_workflow_entry", True) # type: ignore[attr-defined]
return func


class AgentRunner(WorkflowRunner):
"""
Runner specialized for Agent classes.
Expand Down Expand Up @@ -73,9 +54,15 @@ def __init__(
)
self._default_http_paths: set[str] = set()

# In-memory store of managed agents - used for handling shutdown
self._managed_agents: List[
AgentComponents
] = [] # AgentComponents is the lowest common denominator between orchestrators and agents.
self._lock: Lock = Lock()

async def run(
self,
agent: Any,
agent: AgentComponents,
payload: Optional[Union[str, Dict[str, Any]]] = None,
*,
instance_id: Optional[str] = None,
Expand Down Expand Up @@ -112,6 +99,14 @@ async def run(
wait,
timeout_in_seconds,
)
try:
agent.start()
except RuntimeError:
# The agent is already started
pass
with self._lock:
if agent not in self._managed_agents:
self._managed_agents.append(agent)

entry = self.discover_entry(agent)
logger.debug("[%s] Discovered workflow entry: %s", self._name, entry.__name__)
Expand All @@ -128,7 +123,7 @@ async def run(

def run_sync(
self,
agent: Any,
agent: AgentComponents,
payload: Optional[Union[str, Dict[str, Any]]] = None,
*,
instance_id: Optional[str] = None,
Expand Down Expand Up @@ -223,13 +218,13 @@ def _runner() -> None:
finally:
loop.close()

t = threading.Thread(target=_runner, daemon=True)
t = Thread(target=_runner, daemon=True)
t.start()
return fut.result()

def register_routes(
self,
agent: Any,
agent: AgentComponents,
*,
fastapi_app: Optional[FastAPI] = None,
delivery_mode: Literal["sync", "async"] = "sync",
Expand All @@ -252,6 +247,16 @@ def register_routes(
fetch_payloads: Whether to fetch input/output payloads for awaited workflows.
log_outcome: Whether to log the final outcome of awaited workflows.
"""

try:
agent.start()
except RuntimeError:
# The agent is already started
pass
with self._lock:
if agent not in self._managed_agents:
self._managed_agents.append(agent)

self._wire_pubsub_routes(
agent=agent,
delivery_mode=delivery_mode,
Expand All @@ -265,7 +270,9 @@ def register_routes(
if fastapi_app is not None:
self._wire_http_routes(agent=agent, fastapi_app=fastapi_app)

def _build_pubsub_specs(self, agent: Any, config: Any) -> list[PubSubRouteSpec]:
def _build_pubsub_specs(
self, agent: AgentComponents, config: Any
) -> list[PubSubRouteSpec]:
handlers = get_decorated_methods(agent, "_is_message_handler")
if not handlers:
return []
Expand Down Expand Up @@ -300,7 +307,7 @@ def _build_pubsub_specs(self, agent: Any, config: Any) -> list[PubSubRouteSpec]:
def _wire_pubsub_routes(
self,
*,
agent: Any,
agent: AgentComponents,
delivery_mode: Literal["sync", "async"],
queue_maxsize: int,
await_result: bool,
Expand Down Expand Up @@ -339,7 +346,9 @@ def _wire_pubsub_routes(
self._pubsub_closers.extend(closers)
self._wired_pubsub = True

def _wire_http_routes(self, *, agent: Any, fastapi_app: Optional[FastAPI]) -> None:
def _wire_http_routes(
self, *, agent: AgentComponents, fastapi_app: Optional[FastAPI]
) -> None:
if fastapi_app is None or self._wired_http:
return

Expand All @@ -352,7 +361,7 @@ def _wire_http_routes(self, *, agent: Any, fastapi_app: Optional[FastAPI]) -> No

def subscribe(
self,
agent: Any,
agent: AgentComponents,
*,
delivery_mode: Literal["sync", "async"] = "sync",
queue_maxsize: int = 1024,
Expand All @@ -376,6 +385,17 @@ def subscribe(
Returns:
The runner (to allow fluent chaining).
"""

try:
agent.start()
except RuntimeError:
# The agent is already started
pass

with self._lock:
if agent not in self._managed_agents:
self._managed_agents.append(agent)

self._wire_pubsub_routes(
agent=agent,
delivery_mode=delivery_mode,
Expand All @@ -389,7 +409,7 @@ def subscribe(

def serve(
self,
agent: Any,
agent: AgentComponents,
*,
app: Optional[FastAPI] = None,
host: str = "0.0.0.0",
Expand Down Expand Up @@ -422,8 +442,19 @@ def serve(
Returns:
The FastAPI application with the workflow routes.
"""

fastapi_app = app or FastAPI(title="Dapr Agent Service", version="1.0.0")

try:
agent.start()
except RuntimeError:
# The agent is already started
pass

with self._lock:
if agent not in self._managed_agents:
self._managed_agents.append(agent)

self.subscribe(
agent,
delivery_mode=delivery_mode,
Expand Down Expand Up @@ -476,7 +507,7 @@ def _mount_service_routes(
self,
*,
fastapi_app: FastAPI,
agent: Any,
agent: AgentComponents,
entry_path: str,
status_path: str,
workflow_component: str,
Expand Down Expand Up @@ -550,3 +581,32 @@ async def _get_status(instance_id: str) -> dict:
tags=["workflow"],
)
logger.info("Mounted default workflow status endpoint at %s", status_path)

def shutdown(self, agent: Optional[AgentComponents]) -> None:
"""
Unwire subscriptions and close owned clients.

Args:
agent: Durable agent instance.

Returns:
None
"""

if agent:
# We need to shutdown a single agent
# First verify we're managing it
with self._lock:
if agent in self._managed_agents:
agent.stop() # This is safe as they'll return None if not started
self._managed_agents.remove(agent)
return
try:
self.unwire_pubsub()
finally:
with self._lock:
agents = list(self._managed_agents)
for ag in agents:
ag.stop()
self._close_dapr_client()
self._close_wf_client()
4 changes: 1 addition & 3 deletions quickstarts/01-hello-world/03_durable_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ async def main() -> None:
memory=memory,
state=state,
)
travel_planner.start()

runner = AgentRunner()
prompt = "I want to find flights to Paris"
Expand All @@ -91,8 +90,7 @@ async def main() -> None:
print(f"Error running workflow: {e}")
raise
finally:
travel_planner.stop()
runner.shutdown()
runner.shutdown(travel_planner)


if __name__ == "__main__":
Expand Down
4 changes: 1 addition & 3 deletions quickstarts/01-hello-world/03_durable_agent_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ def main() -> None:
memory=memory,
state=state,
)
agent.start()

runner = AgentRunner()
try:
runner.serve(agent, port=8001)
finally:
runner.shutdown()
agent.stop()
runner.shutdown(agent)


if __name__ == "__main__":
Expand Down
Loading