diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8ebf459f6..58f64f2c9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,14 +20,15 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool -from ..experimental.hooks import ( +from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler +from ..hooks import ( AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, + HookProvider, HookRegistry, MessageAddedEvent, ) -from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..models.bedrock import BedrockModel from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer @@ -202,6 +203,7 @@ def __init__( name: Optional[str] = None, description: Optional[str] = None, state: Optional[Union[AgentState, dict]] = None, + hooks: Optional[list[HookProvider]] = None, ): """Initialize the Agent with the specified configuration. @@ -238,6 +240,8 @@ def __init__( Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. + hooks: hooks to be added to the agent hook registry + Defaults to None. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] @@ -301,9 +305,11 @@ def __init__( self.name = name or _DEFAULT_AGENT_NAME self.description = description - self._hooks = HookRegistry() - # Register built-in hook providers (like ConversationManager) here - self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) @property def tool(self) -> ToolCaller: @@ -424,7 +430,7 @@ async def structured_output_async( Raises: ValueError: If no conversation history or prompt is provided. """ - self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: if not self.messages and not prompt: @@ -443,7 +449,7 @@ async def structured_output_async( return event["output"] finally: - self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -509,7 +515,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene Yields: Events from the event loop cycle. """ - self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: yield {"callback": {"init_event_loop": True, **kwargs}} @@ -523,7 +529,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene finally: self.conversation_manager.apply_management(self) - self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Execute the event loop cycle with retry logic for context window limits. @@ -653,4 +659,4 @@ def _end_agent_trace_span( def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) - self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) + self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index c5bf611f7..0ab0a2655 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -18,8 +18,9 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, +) +from ..hooks import ( MessageAddedEvent, - get_registry, ) from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer @@ -120,7 +121,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener tool_specs = agent.tool_registry.get_all_tool_specs() - get_registry(agent).invoke_callbacks( + agent.hooks.invoke_callbacks( BeforeModelInvocationEvent( agent=agent, ) @@ -136,7 +137,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener stop_reason, message, usage, metrics = event["stop"] kwargs.setdefault("request_state", {}) - get_registry(agent).invoke_callbacks( + agent.hooks.invoke_callbacks( AfterModelInvocationEvent( agent=agent, stop_response=AfterModelInvocationEvent.ModelStopResponse( @@ -154,7 +155,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - get_registry(agent).invoke_callbacks( + agent.hooks.invoke_callbacks( AfterModelInvocationEvent( agent=agent, exception=e, @@ -188,7 +189,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener # Add the response message to the conversation agent.messages.append(message) - get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) yield {"callback": {"message": message}} # Update metrics @@ -308,7 +309,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> } ) - before_event = get_registry(agent).invoke_callbacks( + before_event = agent.hooks.invoke_callbacks( BeforeToolInvocationEvent( agent=agent, selected_tool=tool_func, @@ -342,7 +343,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> "content": [{"text": f"Unknown tool: {tool_name}"}], } # for every Before event call, we need to have an AfterEvent call - after_event = get_registry(agent).invoke_callbacks( + after_event = agent.hooks.invoke_callbacks( AfterToolInvocationEvent( agent=agent, selected_tool=selected_tool, @@ -359,7 +360,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> result = event - after_event = get_registry(agent).invoke_callbacks( + after_event = agent.hooks.invoke_callbacks( AfterToolInvocationEvent( agent=agent, selected_tool=selected_tool, @@ -377,7 +378,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event = get_registry(agent).invoke_callbacks( + after_event = agent.hooks.invoke_callbacks( AfterToolInvocationEvent( agent=agent, selected_tool=selected_tool, @@ -454,7 +455,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: } agent.messages.append(tool_result_message) - get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) yield {"callback": {"message": tool_result_message}} if cycle_span: diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 87e16dc54..098d4cf0d 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -1,58 +1,15 @@ -"""Typed hook system for extending agent functionality. - -This module provides a composable mechanism for building objects that can hook -into specific events during the agent lifecycle. The hook system enables both -built-in SDK components and user code to react to or modify agent behavior -through strongly-typed event callbacks. - -Example Usage: - ```python - from strands.hooks import HookProvider, HookRegistry - from strands.hooks.events import StartRequestEvent, EndRequestEvent - - class LoggingHooks(HookProvider): - def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(StartRequestEvent, self.log_start) - registry.add_callback(EndRequestEvent, self.log_end) - - def log_start(self, event: StartRequestEvent) -> None: - print(f"Request started for {event.agent.name}") - - def log_end(self, event: EndRequestEvent) -> None: - print(f"Request completed for {event.agent.name}") - - # Use with agent - agent = Agent(hooks=[LoggingHooks()]) - ``` - -This replaces the older callback_handler approach with a more composable, -type-safe system that supports multiple subscribers per event type. -""" +"""Experimental hook functionality that has not yet reached stability.""" from .events import ( - AfterInvocationEvent, AfterModelInvocationEvent, AfterToolInvocationEvent, - AgentInitializedEvent, - BeforeInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, - MessageAddedEvent, ) -from .registry import HookCallback, HookEvent, HookProvider, HookRegistry, get_registry __all__ = [ - "AgentInitializedEvent", - "BeforeInvocationEvent", - "AfterInvocationEvent", - "BeforeModelInvocationEvent", - "AfterModelInvocationEvent", "BeforeToolInvocationEvent", "AfterToolInvocationEvent", - "MessageAddedEvent", - "HookEvent", - "HookProvider", - "HookCallback", - "HookRegistry", - "get_registry", + "BeforeModelInvocationEvent", + "AfterModelInvocationEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index ae0067320..b0501a9b3 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -1,4 +1,4 @@ -"""Hook events emitted as part of invoking Agents. +"""Experimental hook events emitted as part of invoking Agents. This module defines the events that are emitted as Agents run through the lifecycle of a request. """ @@ -6,62 +6,10 @@ from dataclasses import dataclass from typing import Any, Optional +from ...hooks import HookEvent from ...types.content import Message from ...types.streaming import StopReason from ...types.tools import AgentTool, ToolResult, ToolUse -from .registry import HookEvent - - -@dataclass -class AgentInitializedEvent(HookEvent): - """Event triggered when an agent has finished initialization. - - This event is fired after the agent has been fully constructed and all - built-in components have been initialized. Hook providers can use this - event to perform setup tasks that require a fully initialized agent. - """ - - pass - - -@dataclass -class BeforeInvocationEvent(HookEvent): - """Event triggered at the beginning of a new agent request. - - This event is fired before the agent begins processing a new user request, - before any model inference or tool execution occurs. Hook providers can - use this event to perform request-level setup, logging, or validation. - - This event is triggered at the beginning of the following api calls: - - Agent.__call__ - - Agent.stream_async - - Agent.structured_output - """ - - pass - - -@dataclass -class AfterInvocationEvent(HookEvent): - """Event triggered at the end of an agent request. - - This event is fired after the agent has completed processing a request, - regardless of whether it completed successfully or encountered an error. - Hook providers can use this event for cleanup, logging, or state persistence. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - This event is triggered at the end of the following api calls: - - Agent.__call__ - - Agent.stream_async - - Agent.structured_output - """ - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True @dataclass @@ -173,22 +121,3 @@ class ModelStopResponse: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True - - -@dataclass -class MessageAddedEvent(HookEvent): - """Event triggered when a message is added to the agent's conversation. - - This event is fired whenever the agent adds a new message to its internal - message history, including user messages, assistant responses, and tool - results. Hook providers can use this event for logging, monitoring, or - implementing custom message processing logic. - - Note: This event is only triggered for messages added by the framework - itself, not for messages manually added by tools or external code. - - Attributes: - message: The message that was added to the conversation history. - """ - - message: Message diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py new file mode 100644 index 000000000..77be9d64e --- /dev/null +++ b/src/strands/hooks/__init__.py @@ -0,0 +1,49 @@ +"""Typed hook system for extending agent functionality. + +This module provides a composable mechanism for building objects that can hook +into specific events during the agent lifecycle. The hook system enables both +built-in SDK components and user code to react to or modify agent behavior +through strongly-typed event callbacks. + +Example Usage: + ```python + from strands.hooks import HookProvider, HookRegistry + from strands.hooks.events import StartRequestEvent, EndRequestEvent + + class LoggingHooks(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(StartRequestEvent, self.log_start) + registry.add_callback(EndRequestEvent, self.log_end) + + def log_start(self, event: StartRequestEvent) -> None: + print(f"Request started for {event.agent.name}") + + def log_end(self, event: EndRequestEvent) -> None: + print(f"Request completed for {event.agent.name}") + + # Use with agent + agent = Agent(hooks=[LoggingHooks()]) + ``` + +This replaces the older callback_handler approach with a more composable, +type-safe system that supports multiple subscribers per event type. +""" + +from .events import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + MessageAddedEvent, +) +from .registry import HookCallback, HookEvent, HookProvider, HookRegistry + +__all__ = [ + "AgentInitializedEvent", + "BeforeInvocationEvent", + "AfterInvocationEvent", + "MessageAddedEvent", + "HookEvent", + "HookProvider", + "HookCallback", + "HookRegistry", +] diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py new file mode 100644 index 000000000..42509dc9f --- /dev/null +++ b/src/strands/hooks/events.py @@ -0,0 +1,80 @@ +"""Hook events emitted as part of invoking Agents. + +This module defines the events that are emitted as Agents run through the lifecycle of a request. +""" + +from dataclasses import dataclass + +from ..types.content import Message +from .registry import HookEvent + + +@dataclass +class AgentInitializedEvent(HookEvent): + """Event triggered when an agent has finished initialization. + + This event is fired after the agent has been fully constructed and all + built-in components have been initialized. Hook providers can use this + event to perform setup tasks that require a fully initialized agent. + """ + + pass + + +@dataclass +class BeforeInvocationEvent(HookEvent): + """Event triggered at the beginning of a new agent request. + + This event is fired before the agent begins processing a new user request, + before any model inference or tool execution occurs. Hook providers can + use this event to perform request-level setup, logging, or validation. + + This event is triggered at the beginning of the following api calls: + - Agent.__call__ + - Agent.stream_async + - Agent.structured_output + """ + + pass + + +@dataclass +class AfterInvocationEvent(HookEvent): + """Event triggered at the end of an agent request. + + This event is fired after the agent has completed processing a request, + regardless of whether it completed successfully or encountered an error. + Hook providers can use this event for cleanup, logging, or state persistence. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + This event is triggered at the end of the following api calls: + - Agent.__call__ + - Agent.stream_async + - Agent.structured_output + """ + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class MessageAddedEvent(HookEvent): + """Event triggered when a message is added to the agent's conversation. + + This event is fired whenever the agent adds a new message to its internal + message history, including user messages, assistant responses, and tool + results. Hook providers can use this event for logging, monitoring, or + implementing custom message processing logic. + + Note: This event is only triggered for messages added by the framework + itself, not for messages manually added by tools or external code. + + Attributes: + message: The message that was added to the conversation history. + """ + + message: Message diff --git a/src/strands/experimental/hooks/registry.py b/src/strands/hooks/registry.py similarity index 95% rename from src/strands/experimental/hooks/registry.py rename to src/strands/hooks/registry.py index befa6c397..eecf6c718 100644 --- a/src/strands/experimental/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar if TYPE_CHECKING: - from ...agent import Agent + from ..agent import Agent @dataclass @@ -232,18 +232,3 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No yield from reversed(callbacks) else: yield from callbacks - - -def get_registry(agent: "Agent") -> HookRegistry: - """*Experimental*: Get the hooks registry for the provided agent. - - This function is available while hooks are in experimental preview. - - Args: - agent: The agent whose hook registry should be returned. - - Returns: - The HookRegistry for the given agent. - - """ - return agent._hooks diff --git a/src/strands/experimental/hooks/rules.md b/src/strands/hooks/rules.md similarity index 100% rename from src/strands/experimental/hooks/rules.md rename to src/strands/hooks/rules.md diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 7214ac490..8d7e93253 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,6 +1,6 @@ from typing import Iterator, Tuple, Type -from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry +from strands.hooks import HookEvent, HookProvider, HookRegistry class MockHookProvider(HookProvider): diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 62fc32cb6..d5687b4a3 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, Mock, call, patch +from unittest.mock import ANY, Mock import pytest from pydantic import BaseModel @@ -6,15 +6,16 @@ import strands from strands import Agent from strands.experimental.hooks import ( - AfterInvocationEvent, AfterModelInvocationEvent, AfterToolInvocationEvent, - AgentInitializedEvent, - BeforeInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, +) +from strands.hooks import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, MessageAddedEvent, - get_registry, ) from strands.types.content import Messages from strands.types.tools import ToolResult, ToolUse @@ -77,7 +78,7 @@ def agent( tools=[agent_tool], ) - hooks = get_registry(agent) + hooks = agent.hooks hooks.add_hook(hook_provider) def assert_message_is_last_message_added(event: MessageAddedEvent): @@ -102,14 +103,16 @@ class User(BaseModel): return User(name="Jane Doe", age=30) -@patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks") -def test_agent__init__hooks(mock_invoke_callbacks): +def test_agent__init__hooks(): """Verify that the AgentInitializedEvent is emitted on Agent construction.""" - agent = Agent() + hook_provider = MockHookProvider(event_types=[AgentInitializedEvent]) + agent = Agent(hooks=[hook_provider]) + + length, events = hook_provider.get_events() + + assert length == 1 - # Verify AgentInitialized event was invoked - mock_invoke_callbacks.assert_called_once() - assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent)) + assert next(events) == AgentInitializedEvent(agent=agent) def test_agent_tool_call(agent, hook_provider, agent_tool): diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 57f2a28ef..a2ddeb3de 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -12,6 +12,8 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, +) +from strands.hooks import ( HookProvider, HookRegistry, ) @@ -133,7 +135,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_registry = tool_registry mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() - mock._hooks = hook_registry + mock.hooks = hook_registry return mock diff --git a/tests/strands/experimental/hooks/test_events.py b/tests/strands/experimental/hooks/test_events.py index 61ef40238..56c891666 100644 --- a/tests/strands/experimental/hooks/test_events.py +++ b/tests/strands/experimental/hooks/test_events.py @@ -2,12 +2,11 @@ import pytest -from strands.experimental.hooks import ( +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.hooks import ( AfterInvocationEvent, - AfterToolInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, - BeforeToolInvocationEvent, MessageAddedEvent, ) from strands.types.tools import ToolResult, ToolUse diff --git a/tests/strands/experimental/hooks/test_hook_registry.py b/tests/strands/experimental/hooks/test_hook_registry.py index 0bed07add..693fc93d9 100644 --- a/tests/strands/experimental/hooks/test_hook_registry.py +++ b/tests/strands/experimental/hooks/test_hook_registry.py @@ -5,7 +5,7 @@ import pytest -from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry +from strands.hooks import HookEvent, HookProvider, HookRegistry @dataclass