Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from pydantic import BaseModel

from ..event_loop.event_loop import event_loop_cycle, run_tool
from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent
from ..experimental.hooks import (
AgentInitializedEvent,
EndRequestEvent,
HookRegistry,
MessageAddedEvent,
StartRequestEvent,
)
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..models.bedrock import BedrockModel
from ..telemetry.metrics import EventLoopMetrics
Expand Down Expand Up @@ -424,7 +430,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[

# add the prompt as the last message
if prompt:
self.messages.append({"role": "user", "content": [{"text": prompt}]})
self._append_message({"role": "user", "content": [{"text": prompt}]})

events = self.model.structured_output(output_model, self.messages)
async for event in events:
Expand Down Expand Up @@ -505,7 +511,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene
try:
yield {"callback": {"init_event_loop": True, **kwargs}}

self.messages.append(message)
self._append_message(message)

# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(kwargs)
Expand Down Expand Up @@ -595,10 +601,10 @@ def _record_tool_execution(
}

# Add to message history
messages.append(user_msg)
messages.append(tool_use_msg)
messages.append(tool_result_msg)
messages.append(assistant_msg)
self._append_message(user_msg)
self._append_message(tool_use_msg)
self._append_message(tool_result_msg)
self._append_message(assistant_msg)

def _start_agent_trace_span(self, message: Message) -> None:
"""Starts a trace span for the agent.
Expand Down Expand Up @@ -640,3 +646,8 @@ def _end_agent_trace_span(
trace_attributes["error"] = error

self.tracer.end_agent_span(**trace_attributes)

def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
3 changes: 3 additions & 0 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast

from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ..experimental.hooks.events import MessageAddedEvent
from ..experimental.hooks.registry import get_registry
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
Expand Down Expand Up @@ -166,6 +167,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener

# Add the response message to the conversation
agent.messages.append(message)
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
yield {"callback": {"message": message}}

# Update metrics
Expand Down Expand Up @@ -431,6 +433,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
}

agent.messages.append(tool_result_message)
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
yield {"callback": {"message": tool_result_message}}

if cycle_span:
Expand Down
5 changes: 4 additions & 1 deletion src/strands/experimental/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,21 @@ def log_end(self, event: EndRequestEvent) -> None:
AgentInitializedEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
MessageAddedEvent,
StartRequestEvent,
)
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry, get_registry

__all__ = [
"AgentInitializedEvent",
"StartRequestEvent",
"EndRequestEvent",
"BeforeToolInvocationEvent",
"AfterToolInvocationEvent",
"MessageAddedEvent",
"HookEvent",
"HookProvider",
"HookCallback",
"HookRegistry",
"get_registry",
]
20 changes: 20 additions & 0 deletions src/strands/experimental/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from typing import Any, Optional

from ...types.content import Message
from ...types.tools import AgentTool, ToolResult, ToolUse
from .registry import HookEvent

Expand Down Expand Up @@ -118,3 +119,22 @@ def _can_write(self, name: str) -> bool:
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True


@dataclass
class MessageAddedEvent(HookEvent):
"""Event triggered when a message is added to the agent's conversation.

This event is fired whenever the agent adds a new message to its internal
message history, including user messages, assistant responses, and tool
results. Hook providers can use this event for logging, monitoring, or
implementing custom message processing logic.

Note: This event is only triggered for messages added by the framework
itself, not for messages manually added by tools or external code.

Attributes:
message: The message that was added to the conversation history.
"""

message: Message
92 changes: 85 additions & 7 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,27 @@
AgentInitializedEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
MessageAddedEvent,
StartRequestEvent,
get_registry,
)
from strands.types.content import Messages
from strands.types.tools import ToolResult, ToolUse
from tests.fixtures.mock_hook_provider import MockHookProvider
from tests.fixtures.mocked_model_provider import MockedModelProvider


@pytest.fixture
def hook_provider():
return MockHookProvider(
[AgentInitializedEvent, StartRequestEvent, EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent]
[
AgentInitializedEvent,
StartRequestEvent,
EndRequestEvent,
AfterToolInvocationEvent,
BeforeToolInvocationEvent,
MessageAddedEvent,
]
)


Expand Down Expand Up @@ -63,8 +73,13 @@ def agent(
tools=[agent_tool],
)

# for now, hooks are private
agent._hooks.add_hook(hook_provider)
hooks = get_registry(agent)
hooks.add_hook(hook_provider)

def assert_message_is_last_message_added(event: MessageAddedEvent):
assert event.agent.messages[-1] == event.message

hooks.add_callback(MessageAddedEvent, assert_message_is_last_message_added)

return agent

Expand All @@ -88,15 +103,49 @@ def test_agent__init__hooks(mock_invoke_callbacks):
assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent))


def test_agent_tool_call(agent, hook_provider, agent_tool):
agent.tool.tool_decorated(random_string="a string")

length, events = hook_provider.get_events()

tool_use: ToolUse = {"input": {"random_string": "a string"}, "name": "tool_decorated", "toolUseId": ANY}
result: ToolResult = {"content": [{"text": "gnirts a"}], "status": "success", "toolUseId": ANY}

assert length == 6

assert next(events) == BeforeToolInvocationEvent(
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
)
assert next(events) == AfterToolInvocationEvent(
agent=agent,
selected_tool=agent_tool,
tool_use=tool_use,
kwargs=ANY,
result=result,
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])

assert len(agent.messages) == 4


def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
"""Verify that the correct hook events are emitted as part of __call__."""

agent("test message")

length, events = hook_provider.get_events()

assert length == 4
assert length == 8

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
assert next(events) == BeforeToolInvocationEvent(
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
)
Expand All @@ -107,8 +156,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
kwargs=ANY,
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 4


@pytest.mark.asyncio
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use):
Expand All @@ -123,9 +176,14 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u

length, events = hook_provider.get_events()

assert length == 4
assert length == 8

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
assert next(events) == BeforeToolInvocationEvent(
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
)
Expand All @@ -136,16 +194,28 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
kwargs=ANY,
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 4


def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
"""Verify that the correct hook events are emitted as part of structured_output."""

agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
agent.structured_output(type(user), "example prompt")

assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
length, events = hook_provider.get_events()

assert length == 3

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 1


@pytest.mark.asyncio
Expand All @@ -155,4 +225,12 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
await agent.structured_output_async(type(user), "example prompt")

assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)]
length, events = hook_provider.get_events()

assert length == 3

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 1
18 changes: 17 additions & 1 deletion tests/strands/experimental/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import pytest

from strands.experimental.hooks.events import (
from strands.experimental.hooks import (
AfterToolInvocationEvent,
AgentInitializedEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
MessageAddedEvent,
StartRequestEvent,
)
from strands.types.tools import ToolResult, ToolUse
Expand Down Expand Up @@ -49,6 +50,11 @@ def start_request_event(agent):
return StartRequestEvent(agent=agent)


@pytest.fixture
def messaged_added_event(agent):
return MessageAddedEvent(agent=agent, message=Mock())


@pytest.fixture
def end_request_event(agent):
return EndRequestEvent(agent=agent)
Expand Down Expand Up @@ -78,6 +84,7 @@ def after_tool_event(agent, tool, tool_use, tool_kwargs, tool_result):
def test_event_should_reverse_callbacks(
initialized_event,
start_request_event,
messaged_added_event,
end_request_event,
before_tool_event,
after_tool_event,
Expand All @@ -86,13 +93,22 @@ def test_event_should_reverse_callbacks(

assert initialized_event.should_reverse_callbacks == False # noqa: E712

assert messaged_added_event.should_reverse_callbacks == False # noqa: E712

assert start_request_event.should_reverse_callbacks == False # noqa: E712
assert end_request_event.should_reverse_callbacks == True # noqa: E712

assert before_tool_event.should_reverse_callbacks == False # noqa: E712
assert after_tool_event.should_reverse_callbacks == True # noqa: E712


def test_message_added_event_cannot_write_properties(messaged_added_event):
with pytest.raises(AttributeError, match="Property agent is not writable"):
messaged_added_event.agent = Mock()
with pytest.raises(AttributeError, match="Property message is not writable"):
messaged_added_event.message = {}


def test_before_tool_invocation_event_can_write_properties(before_tool_event):
new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={})
before_tool_event.selected_tool = None # Should not raise
Expand Down