Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
from pydantic import BaseModel

from ..event_loop.event_loop import event_loop_cycle, run_tool
from ..experimental.hooks import (
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
AfterInvocationEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
HookProvider,
HookRegistry,
MessageAddedEvent,
)
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..models.bedrock import BedrockModel
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer
Expand Down Expand Up @@ -202,6 +203,7 @@ def __init__(
name: Optional[str] = None,
description: Optional[str] = None,
state: Optional[Union[AgentState, dict]] = None,
hooks: Optional[list[HookProvider]] = None,
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -238,6 +240,8 @@ def __init__(
Defaults to None.
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
Defaults to an empty AgentState object.
hooks: hooks to be added to the agent hook registry
Defaults to None.
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []
Expand Down Expand Up @@ -301,9 +305,11 @@ def __init__(
self.name = name or _DEFAULT_AGENT_NAME
self.description = description

self._hooks = HookRegistry()
# Register built-in hook providers (like ConversationManager) here
self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
self.hooks = HookRegistry()
if hooks:
for hook in hooks:
self.hooks.add_hook(hook)
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))

@property
def tool(self) -> ToolCaller:
Expand Down Expand Up @@ -424,7 +430,7 @@ async def structured_output_async(
Raises:
ValueError: If no conversation history or prompt is provided.
"""
self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
if not self.messages and not prompt:
Expand All @@ -443,7 +449,7 @@ async def structured_output_async(
return event["output"]

finally:
self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand Down Expand Up @@ -509,7 +515,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene
Yields:
Events from the event loop cycle.
"""
self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
yield {"callback": {"init_event_loop": True, **kwargs}}
Expand All @@ -523,7 +529,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene

finally:
self.conversation_manager.apply_management(self)
self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the event loop cycle with retry logic for context window limits.
Expand Down Expand Up @@ -653,4 +659,4 @@ def _end_agent_trace_span(
def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
21 changes: 11 additions & 10 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
AfterToolInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
)
from ..hooks import (
MessageAddedEvent,
get_registry,
)
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
Expand Down Expand Up @@ -120,7 +121,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener

tool_specs = agent.tool_registry.get_all_tool_specs()

get_registry(agent).invoke_callbacks(
agent.hooks.invoke_callbacks(
BeforeModelInvocationEvent(
agent=agent,
)
Expand All @@ -136,7 +137,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
stop_reason, message, usage, metrics = event["stop"]
kwargs.setdefault("request_state", {})

get_registry(agent).invoke_callbacks(
agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
Expand All @@ -154,7 +155,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)

get_registry(agent).invoke_callbacks(
agent.hooks.invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
exception=e,
Expand Down Expand Up @@ -188,7 +189,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener

# Add the response message to the conversation
agent.messages.append(message)
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
yield {"callback": {"message": message}}

# Update metrics
Expand Down Expand Up @@ -308,7 +309,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) ->
}
)

before_event = get_registry(agent).invoke_callbacks(
before_event = agent.hooks.invoke_callbacks(
BeforeToolInvocationEvent(
agent=agent,
selected_tool=tool_func,
Expand Down Expand Up @@ -342,7 +343,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) ->
"content": [{"text": f"Unknown tool: {tool_name}"}],
}
# for every Before event call, we need to have an AfterEvent call
after_event = get_registry(agent).invoke_callbacks(
after_event = agent.hooks.invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
Expand All @@ -359,7 +360,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) ->

result = event

after_event = get_registry(agent).invoke_callbacks(
after_event = agent.hooks.invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
Expand All @@ -377,7 +378,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) ->
"status": "error",
"content": [{"text": f"Error: {str(e)}"}],
}
after_event = get_registry(agent).invoke_callbacks(
after_event = agent.hooks.invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
Expand Down Expand Up @@ -454,7 +455,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
}

agent.messages.append(tool_result_message)
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
yield {"callback": {"message": tool_result_message}}

if cycle_span:
Expand Down
49 changes: 3 additions & 46 deletions src/strands/experimental/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,15 @@
"""Typed hook system for extending agent functionality.

This module provides a composable mechanism for building objects that can hook
into specific events during the agent lifecycle. The hook system enables both
built-in SDK components and user code to react to or modify agent behavior
through strongly-typed event callbacks.

Example Usage:
```python
from strands.hooks import HookProvider, HookRegistry
from strands.hooks.events import StartRequestEvent, EndRequestEvent

class LoggingHooks(HookProvider):
def register_hooks(self, registry: HookRegistry) -> None:
registry.add_callback(StartRequestEvent, self.log_start)
registry.add_callback(EndRequestEvent, self.log_end)

def log_start(self, event: StartRequestEvent) -> None:
print(f"Request started for {event.agent.name}")

def log_end(self, event: EndRequestEvent) -> None:
print(f"Request completed for {event.agent.name}")

# Use with agent
agent = Agent(hooks=[LoggingHooks()])
```

This replaces the older callback_handler approach with a more composable,
type-safe system that supports multiple subscribers per event type.
"""
"""Experimental hook functionality that has not yet reached stability."""

from .events import (
AfterInvocationEvent,
AfterModelInvocationEvent,
AfterToolInvocationEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
MessageAddedEvent,
)
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry, get_registry

__all__ = [
"AgentInitializedEvent",
"BeforeInvocationEvent",
"AfterInvocationEvent",
"BeforeModelInvocationEvent",
"AfterModelInvocationEvent",
"BeforeToolInvocationEvent",
"AfterToolInvocationEvent",
"MessageAddedEvent",
"HookEvent",
"HookProvider",
"HookCallback",
"HookRegistry",
"get_registry",
"BeforeModelInvocationEvent",
"AfterModelInvocationEvent",
]
75 changes: 2 additions & 73 deletions src/strands/experimental/hooks/events.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,15 @@
"""Hook events emitted as part of invoking Agents.
"""Experimental hook events emitted as part of invoking Agents.

This module defines the events that are emitted as Agents run through the lifecycle of a request.
"""

from dataclasses import dataclass
from typing import Any, Optional

from ...hooks import HookEvent
from ...types.content import Message
from ...types.streaming import StopReason
from ...types.tools import AgentTool, ToolResult, ToolUse
from .registry import HookEvent


@dataclass
class AgentInitializedEvent(HookEvent):
"""Event triggered when an agent has finished initialization.

This event is fired after the agent has been fully constructed and all
built-in components have been initialized. Hook providers can use this
event to perform setup tasks that require a fully initialized agent.
"""

pass


@dataclass
class BeforeInvocationEvent(HookEvent):
"""Event triggered at the beginning of a new agent request.

This event is fired before the agent begins processing a new user request,
before any model inference or tool execution occurs. Hook providers can
use this event to perform request-level setup, logging, or validation.

This event is triggered at the beginning of the following api calls:
- Agent.__call__
- Agent.stream_async
- Agent.structured_output
"""

pass


@dataclass
class AfterInvocationEvent(HookEvent):
"""Event triggered at the end of an agent request.

This event is fired after the agent has completed processing a request,
regardless of whether it completed successfully or encountered an error.
Hook providers can use this event for cleanup, logging, or state persistence.

Note: This event uses reverse callback ordering, meaning callbacks registered
later will be invoked first during cleanup.

This event is triggered at the end of the following api calls:
- Agent.__call__
- Agent.stream_async
- Agent.structured_output
"""

@property
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True


@dataclass
Expand Down Expand Up @@ -173,22 +121,3 @@ class ModelStopResponse:
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True


@dataclass
class MessageAddedEvent(HookEvent):
"""Event triggered when a message is added to the agent's conversation.

This event is fired whenever the agent adds a new message to its internal
message history, including user messages, assistant responses, and tool
results. Hook providers can use this event for logging, monitoring, or
implementing custom message processing logic.

Note: This event is only triggered for messages added by the framework
itself, not for messages manually added by tools or external code.

Attributes:
message: The message that was added to the conversation history.
"""

message: Message
49 changes: 49 additions & 0 deletions src/strands/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Typed hook system for extending agent functionality.

This module provides a composable mechanism for building objects that can hook
into specific events during the agent lifecycle. The hook system enables both
built-in SDK components and user code to react to or modify agent behavior
through strongly-typed event callbacks.

Example Usage:
```python
from strands.hooks import HookProvider, HookRegistry
from strands.hooks.events import StartRequestEvent, EndRequestEvent

class LoggingHooks(HookProvider):
def register_hooks(self, registry: HookRegistry) -> None:
registry.add_callback(StartRequestEvent, self.log_start)
registry.add_callback(EndRequestEvent, self.log_end)

def log_start(self, event: StartRequestEvent) -> None:
print(f"Request started for {event.agent.name}")

def log_end(self, event: EndRequestEvent) -> None:
print(f"Request completed for {event.agent.name}")

# Use with agent
agent = Agent(hooks=[LoggingHooks()])
```

This replaces the older callback_handler approach with a more composable,
type-safe system that supports multiple subscribers per event type.
"""

from .events import (
AfterInvocationEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
MessageAddedEvent,
)
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry

__all__ = [
"AgentInitializedEvent",
"BeforeInvocationEvent",
"AfterInvocationEvent",
"MessageAddedEvent",
"HookEvent",
"HookProvider",
"HookCallback",
"HookRegistry",
]
Loading