diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index cbe36d2f9..5afe4ff12 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,7 +20,13 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool -from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent +from ..experimental.hooks import ( + AgentInitializedEvent, + EndRequestEvent, + HookRegistry, + MessageAddedEvent, + StartRequestEvent, +) from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..models.bedrock import BedrockModel from ..telemetry.metrics import EventLoopMetrics @@ -424,7 +430,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[ # add the prompt as the last message if prompt: - self.messages.append({"role": "user", "content": [{"text": prompt}]}) + self._append_message({"role": "user", "content": [{"text": prompt}]}) events = self.model.structured_output(output_model, self.messages) async for event in events: @@ -505,7 +511,7 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene try: yield {"callback": {"init_event_loop": True, **kwargs}} - self.messages.append(message) + self._append_message(message) # Execute the event loop cycle with retry logic for context limits events = self._execute_event_loop_cycle(kwargs) @@ -595,10 +601,10 @@ def _record_tool_execution( } # Add to message history - messages.append(user_msg) - messages.append(tool_use_msg) - messages.append(tool_result_msg) - messages.append(assistant_msg) + self._append_message(user_msg) + self._append_message(tool_use_msg) + self._append_message(tool_result_msg) + self._append_message(assistant_msg) def _start_agent_trace_span(self, message: Message) -> None: """Starts a trace span for the agent. @@ -640,3 +646,8 @@ def _end_agent_trace_span( trace_attributes["error"] = error self.tracer.end_agent_span(**trace_attributes) + + def _append_message(self, message: Message) -> None: + """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" + self.messages.append(message) + self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 0c7cb4124..c2152e35b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from ..experimental.hooks.events import MessageAddedEvent from ..experimental.hooks.registry import get_registry from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer @@ -166,6 +167,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener # Add the response message to the conversation agent.messages.append(message) + get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) yield {"callback": {"message": message}} # Update metrics @@ -431,6 +433,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator: } agent.messages.append(tool_result_message) + get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) yield {"callback": {"message": tool_result_message}} if cycle_span: diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 61bd6ac3e..32e4be9ad 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -34,9 +34,10 @@ def log_end(self, event: EndRequestEvent) -> None: AgentInitializedEvent, BeforeToolInvocationEvent, EndRequestEvent, + MessageAddedEvent, StartRequestEvent, ) -from .registry import HookCallback, HookEvent, HookProvider, HookRegistry +from .registry import HookCallback, HookEvent, HookProvider, HookRegistry, get_registry __all__ = [ "AgentInitializedEvent", @@ -44,8 +45,10 @@ def log_end(self, event: EndRequestEvent) -> None: "EndRequestEvent", "BeforeToolInvocationEvent", "AfterToolInvocationEvent", + "MessageAddedEvent", "HookEvent", "HookProvider", "HookCallback", "HookRegistry", + "get_registry", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 559f1051d..980f084cb 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Any, Optional +from ...types.content import Message from ...types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -118,3 +119,22 @@ def _can_write(self, name: str) -> bool: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +@dataclass +class MessageAddedEvent(HookEvent): + """Event triggered when a message is added to the agent's conversation. + + This event is fired whenever the agent adds a new message to its internal + message history, including user messages, assistant responses, and tool + results. Hook providers can use this event for logging, monitoring, or + implementing custom message processing logic. + + Note: This event is only triggered for messages added by the framework + itself, not for messages manually added by tools or external code. + + Attributes: + message: The message that was added to the conversation history. + """ + + message: Message diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 22f261b15..8eb6a75b6 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -10,9 +10,12 @@ AgentInitializedEvent, BeforeToolInvocationEvent, EndRequestEvent, + MessageAddedEvent, StartRequestEvent, + get_registry, ) from strands.types.content import Messages +from strands.types.tools import ToolResult, ToolUse from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -20,7 +23,14 @@ @pytest.fixture def hook_provider(): return MockHookProvider( - [AgentInitializedEvent, StartRequestEvent, EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent] + [ + AgentInitializedEvent, + StartRequestEvent, + EndRequestEvent, + AfterToolInvocationEvent, + BeforeToolInvocationEvent, + MessageAddedEvent, + ] ) @@ -63,8 +73,13 @@ def agent( tools=[agent_tool], ) - # for now, hooks are private - agent._hooks.add_hook(hook_provider) + hooks = get_registry(agent) + hooks.add_hook(hook_provider) + + def assert_message_is_last_message_added(event: MessageAddedEvent): + assert event.agent.messages[-1] == event.message + + hooks.add_callback(MessageAddedEvent, assert_message_is_last_message_added) return agent @@ -88,6 +103,34 @@ def test_agent__init__hooks(mock_invoke_callbacks): assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent)) +def test_agent_tool_call(agent, hook_provider, agent_tool): + agent.tool.tool_decorated(random_string="a string") + + length, events = hook_provider.get_events() + + tool_use: ToolUse = {"input": {"random_string": "a string"}, "name": "tool_decorated", "toolUseId": ANY} + result: ToolResult = {"content": [{"text": "gnirts a"}], "status": "success", "toolUseId": ANY} + + assert length == 6 + + assert next(events) == BeforeToolInvocationEvent( + agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY + ) + assert next(events) == AfterToolInvocationEvent( + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + kwargs=ANY, + result=result, + ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) + + assert len(agent.messages) == 4 + + def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): """Verify that the correct hook events are emitted as part of __call__.""" @@ -95,8 +138,14 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): length, events = hook_provider.get_events() - assert length == 4 + assert length == 8 + assert next(events) == StartRequestEvent(agent=agent) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[0], + ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolInvocationEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY ) @@ -107,8 +156,12 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): kwargs=ANY, result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) assert next(events) == EndRequestEvent(agent=agent) + assert len(agent.messages) == 4 + @pytest.mark.asyncio async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use): @@ -123,9 +176,14 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u length, events = hook_provider.get_events() - assert length == 4 + assert length == 8 assert next(events) == StartRequestEvent(agent=agent) + assert next(events) == MessageAddedEvent( + agent=agent, + message=agent.messages[0], + ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolInvocationEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY ) @@ -136,8 +194,12 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u kwargs=ANY, result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) assert next(events) == EndRequestEvent(agent=agent) + assert len(agent.messages) == 4 + def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output.""" @@ -145,7 +207,15 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): agent.model.structured_output = Mock(return_value=agenerator([{"output": user}])) agent.structured_output(type(user), "example prompt") - assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] + length, events = hook_provider.get_events() + + assert length == 3 + + assert next(events) == StartRequestEvent(agent=agent) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) + assert next(events) == EndRequestEvent(agent=agent) + + assert len(agent.messages) == 1 @pytest.mark.asyncio @@ -155,4 +225,12 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a agent.model.structured_output = Mock(return_value=agenerator([{"output": user}])) await agent.structured_output_async(type(user), "example prompt") - assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] + length, events = hook_provider.get_events() + + assert length == 3 + + assert next(events) == StartRequestEvent(agent=agent) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) + assert next(events) == EndRequestEvent(agent=agent) + + assert len(agent.messages) == 1 diff --git a/tests/strands/experimental/hooks/test_events.py b/tests/strands/experimental/hooks/test_events.py index c9c5ecdd7..45446f215 100644 --- a/tests/strands/experimental/hooks/test_events.py +++ b/tests/strands/experimental/hooks/test_events.py @@ -2,11 +2,12 @@ import pytest -from strands.experimental.hooks.events import ( +from strands.experimental.hooks import ( AfterToolInvocationEvent, AgentInitializedEvent, BeforeToolInvocationEvent, EndRequestEvent, + MessageAddedEvent, StartRequestEvent, ) from strands.types.tools import ToolResult, ToolUse @@ -49,6 +50,11 @@ def start_request_event(agent): return StartRequestEvent(agent=agent) +@pytest.fixture +def messaged_added_event(agent): + return MessageAddedEvent(agent=agent, message=Mock()) + + @pytest.fixture def end_request_event(agent): return EndRequestEvent(agent=agent) @@ -78,6 +84,7 @@ def after_tool_event(agent, tool, tool_use, tool_kwargs, tool_result): def test_event_should_reverse_callbacks( initialized_event, start_request_event, + messaged_added_event, end_request_event, before_tool_event, after_tool_event, @@ -86,6 +93,8 @@ def test_event_should_reverse_callbacks( assert initialized_event.should_reverse_callbacks == False # noqa: E712 + assert messaged_added_event.should_reverse_callbacks == False # noqa: E712 + assert start_request_event.should_reverse_callbacks == False # noqa: E712 assert end_request_event.should_reverse_callbacks == True # noqa: E712 @@ -93,6 +102,13 @@ def test_event_should_reverse_callbacks( assert after_tool_event.should_reverse_callbacks == True # noqa: E712 +def test_message_added_event_cannot_write_properties(messaged_added_event): + with pytest.raises(AttributeError, match="Property agent is not writable"): + messaged_added_event.agent = Mock() + with pytest.raises(AttributeError, match="Property message is not writable"): + messaged_added_event.message = {} + + def test_before_tool_invocation_event_can_write_properties(before_tool_event): new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) before_tool_event.selected_tool = None # Should not raise