Skip to content

Commit df8936c

Browse files
committed
feat(multi-agent): introduce Graph orchestrator
1 parent f20a405 commit df8936c

File tree

11 files changed

+1284
-34
lines changed

11 files changed

+1284
-34
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ __pycache__*
88
.ruff_cache
99
*.bak
1010
.vscode
11-
dist
11+
dist
12+
repl_state

pyproject.toml

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,57 @@ a2a = [
8989
"fastapi>=0.115.12",
9090
"starlette>=0.46.2",
9191
]
92+
all = [
93+
# anthropic
94+
"anthropic>=0.21.0,<1.0.0",
95+
96+
# dev
97+
"commitizen>=4.4.0,<5.0.0",
98+
"hatch>=1.0.0,<2.0.0",
99+
"moto>=5.1.0,<6.0.0",
100+
"mypy>=1.15.0,<2.0.0",
101+
"pre-commit>=3.2.0,<4.2.0",
102+
"pytest>=8.0.0,<9.0.0",
103+
"pytest-asyncio>=0.26.0,<0.27.0",
104+
"ruff>=0.4.4,<0.5.0",
105+
106+
# docs
107+
"sphinx>=5.0.0,<6.0.0",
108+
"sphinx-rtd-theme>=1.0.0,<2.0.0",
109+
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
110+
111+
# litellm
112+
"litellm>=1.72.6,<1.73.0",
113+
114+
# llama
115+
"llama-api-client>=0.1.0,<1.0.0",
116+
117+
# mistral
118+
"mistralai>=1.8.2",
119+
120+
# ollama
121+
"ollama>=0.4.8,<1.0.0",
122+
123+
# openai
124+
"openai>=1.68.0,<2.0.0",
125+
126+
# otel
127+
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
128+
129+
# a2a
130+
"a2a-sdk>=0.2.6",
131+
"uvicorn>=0.34.2",
132+
"httpx>=0.28.1",
133+
"fastapi>=0.115.12",
134+
"starlette>=0.46.2",
135+
]
92136

93137
[tool.hatch.version]
94138
# Tells Hatch to use your version control system (git) to determine the version.
95139
source = "vcs"
96140

97141
[tool.hatch.envs.hatch-static-analysis]
98-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
142+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "a2a"]
99143
dependencies = [
100144
"mypy>=1.15.0,<2.0.0",
101145
"ruff>=0.11.6,<0.12.0",
@@ -111,15 +155,14 @@ format-fix = [
111155
]
112156
lint-check = [
113157
"ruff check",
114-
# excluding due to A2A and OTEL http exporter dependency conflict
115-
"mypy -p src --exclude src/strands/multiagent"
158+
"mypy -p src"
116159
]
117160
lint-fix = [
118161
"ruff check --fix"
119162
]
120163

121164
[tool.hatch.envs.hatch-test]
122-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
165+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "a2a"]
123166
extra-dependencies = [
124167
"moto>=5.1.0,<6.0.0",
125168
"pytest>=8.0.0,<9.0.0",
@@ -135,35 +178,17 @@ extra-args = [
135178

136179
[tool.hatch.envs.dev]
137180
dev-mode = true
138-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"]
139-
140-
[tool.hatch.envs.a2a]
141-
dev-mode = true
142-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]
143-
144-
[tool.hatch.envs.a2a.scripts]
145-
run = [
146-
"pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}"
147-
]
148-
run-cov = [
149-
"pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
150-
]
151-
lint-check = [
152-
"ruff check",
153-
"mypy -p src/strands/multiagent/a2a"
154-
]
181+
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral"]
155182

156183
[[tool.hatch.envs.hatch-test.matrix]]
157184
python = ["3.13", "3.12", "3.11", "3.10"]
158185

159186
[tool.hatch.envs.hatch-test.scripts]
160187
run = [
161-
# excluding due to A2A and OTEL http exporter dependency conflict
162-
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a"
188+
"pytest{env:HATCH_TEST_ARGS:} {args}"
163189
]
164190
run-cov = [
165-
# excluding due to A2A and OTEL http exporter dependency conflict
166-
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a"
191+
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
167192
]
168193

169194
cov-combine = []
@@ -198,10 +223,6 @@ prepare = [
198223
"hatch run test-lint",
199224
"hatch test --all"
200225
]
201-
test-a2a = [
202-
# required to run manually due to A2A and OTEL http exporter dependency conflict
203-
"hatch -e a2a run run {args}"
204-
]
205226

206227
[tool.mypy]
207228
python_version = "3.10"

src/strands/agent/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
18+
from uuid import uuid4
1819

1920
from opentelemetry import trace
2021
from pydantic import BaseModel
@@ -191,6 +192,7 @@ def __init__(
191192
load_tools_from_directory: bool = True,
192193
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
193194
*,
195+
agent_id: Optional[str] = None,
194196
name: Optional[str] = None,
195197
description: Optional[str] = None,
196198
state: Optional[Union[AgentState, dict]] = None,
@@ -226,6 +228,8 @@ def __init__(
226228
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
227229
Defaults to True.
228230
trace_attributes: Custom trace attributes to apply to the agent's trace span.
231+
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
232+
If None, a UUID is generated.
229233
name: name of the Agent
230234
Defaults to None.
231235
description: description of what the Agent does
@@ -240,6 +244,9 @@ def __init__(
240244
self.messages = messages if messages is not None else []
241245

242246
self.system_prompt = system_prompt
247+
self.agent_id = agent_id or str(uuid4())
248+
self.name = name
249+
self.description = description
243250

244251
# If not provided, create a new PrintingCallbackHandler instance
245252
# If explicitly set to None, use null_callback_handler
@@ -305,8 +312,6 @@ def __init__(
305312
self.state = AgentState()
306313

307314
self.tool_caller = Agent.ToolCaller(self)
308-
self.name = name
309-
self.description = description
310315

311316
@property
312317
def tool(self) -> ToolCaller:

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def apply_management(self, agent: "Agent") -> None:
7575

7676
if len(messages) <= self.window_size:
7777
logger.debug(
78-
"window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size
78+
"message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
7979
)
8080
return
8181
self.reduce_context(agent)

src/strands/multiagent/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,13 @@
99
"""
1010

1111
from . import a2a
12+
from .base import MultiAgentBase, MultiAgentResult
13+
from .graph import GraphBuilder, GraphResult
1214

13-
__all__ = ["a2a"]
15+
__all__ = [
16+
"a2a",
17+
"GraphBuilder",
18+
"GraphResult",
19+
"MultiAgentBase",
20+
"MultiAgentResult",
21+
]

src/strands/multiagent/base.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Multi-Agent Base Class.
2+
3+
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
4+
"""
5+
6+
import logging
7+
from abc import ABC, abstractmethod
8+
from dataclasses import dataclass, field
9+
from typing import Any, Dict, List, Union
10+
11+
from ..agent import AgentResult
12+
from ..types.event_loop import Metrics, Usage
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@dataclass
18+
class NodeResult:
19+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
20+
21+
# Core result data - single AgentResult or nested MultiAgentResult
22+
results: Union[AgentResult, "MultiAgentResult"]
23+
24+
# Execution metadata
25+
execution_time: float = 0.0
26+
status: Any = None
27+
28+
# Accumulated metrics from this node and all children
29+
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
30+
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
31+
execution_count: int = 0
32+
33+
def get_agent_results(self) -> List[AgentResult]:
34+
"""Get all AgentResult objects from this node, flattened if nested."""
35+
if isinstance(self.results, AgentResult):
36+
return [self.results]
37+
else:
38+
# Flatten nested results from MultiAgentResult
39+
flattened = []
40+
for nested_node_result in self.results.results.values():
41+
flattened.extend(nested_node_result.get_agent_results())
42+
return flattened
43+
44+
45+
@dataclass
46+
class MultiAgentResult:
47+
"""Result from multi-agent execution with accumulated metrics."""
48+
49+
results: Dict[str, NodeResult]
50+
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
51+
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
52+
execution_count: int = 0
53+
execution_time: float = 0.0
54+
55+
56+
class MultiAgentBase(ABC):
57+
"""Base class for multi-agent helpers.
58+
59+
This class integrates with existing Strands Agent instances and provides
60+
multi-agent orchestration capabilities.
61+
"""
62+
63+
@abstractmethod
64+
# TODO: for task - multi-modal input (Message), list of messages
65+
async def execute(self, task: str) -> MultiAgentResult:
66+
"""Execute task."""
67+
raise NotImplementedError("execute not implemented")
68+
69+
@abstractmethod
70+
# TODO: for task - multi-modal input (Message), list of messages
71+
async def resume(self, task: str, state: Any) -> MultiAgentResult:
72+
"""Resume task from previous state."""
73+
raise NotImplementedError("resume not implemented")

0 commit comments

Comments
 (0)