From bba6a1056d094980d240e19bca6d6fecee74ea5b Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Thu, 10 Jul 2025 12:31:28 -0400 Subject: [PATCH] Now callers can subscribe and modify responses before they're sent to the model for processing. We also don't support structured output as the parameters are not applicable --- src/strands/event_loop/event_loop.py | 75 +++++++++++------- src/strands/experimental/hooks/__init__.py | 4 + src/strands/experimental/hooks/events.py | 54 +++++++++++++ src/strands/experimental/hooks/rules.md | 20 +++++ tests/strands/agent/test_agent_hooks.py | 63 ++++++++++++++- tests/strands/event_loop/test_event_loop.py | 88 +++++++++++++++------ 6 files changed, 251 insertions(+), 53 deletions(-) create mode 100644 src/strands/experimental/hooks/rules.md diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index c2152e35b..c5bf611f7 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -13,9 +13,14 @@ import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator, cast -from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent -from ..experimental.hooks.events import MessageAddedEvent -from ..experimental.hooks.registry import get_registry +from ..experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, + MessageAddedEvent, + get_registry, +) from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools @@ -115,6 +120,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener tool_specs = agent.tool_registry.get_all_tool_specs() + get_registry(agent).invoke_callbacks( + BeforeModelInvocationEvent( + agent=agent, + ) + ) + try: # TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding # to the callback handler. This will be revisited when migrating to strongly typed events. @@ -125,40 +136,50 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener stop_reason, message, usage, metrics = event["stop"] kwargs.setdefault("request_state", {}) + get_registry(agent).invoke_callbacks( + AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_reason=stop_reason, + message=message, + ), + ) + ) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) break # Success! Break out of retry loop - except ContextWindowOverflowException as e: - if model_invoke_span: - tracer.end_span_with_error(model_invoke_span, str(e), e) - raise e - - except ModelThrottledException as e: + except Exception as e: if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - if attempt + 1 == MAX_ATTEMPTS: - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} - raise e - - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, + get_registry(agent).invoke_callbacks( + AfterModelInvocationEvent( + agent=agent, + exception=e, + ) ) - time.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) - yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}} + if isinstance(e, ModelThrottledException): + if attempt + 1 == MAX_ATTEMPTS: + yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + raise e - except Exception as e: - if model_invoke_span: - tracer.end_span_with_error(model_invoke_span, str(e), e) - raise e + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered " + "| delaying before next retry", + current_delay, + MAX_ATTEMPTS, + attempt + 1, + ) + time.sleep(current_delay) + current_delay = min(current_delay * 2, MAX_DELAY) + + yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}} + else: + raise e try: # Add message in trace and mark the end of the stream messages trace diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 32e4be9ad..e6264497c 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -30,8 +30,10 @@ def log_end(self, event: EndRequestEvent) -> None: """ from .events import ( + AfterModelInvocationEvent, AfterToolInvocationEvent, AgentInitializedEvent, + BeforeModelInvocationEvent, BeforeToolInvocationEvent, EndRequestEvent, MessageAddedEvent, @@ -43,6 +45,8 @@ def log_end(self, event: EndRequestEvent) -> None: "AgentInitializedEvent", "StartRequestEvent", "EndRequestEvent", + "BeforeModelInvocationEvent", + "AfterModelInvocationEvent", "BeforeToolInvocationEvent", "AfterToolInvocationEvent", "MessageAddedEvent", diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 980f084cb..8dcec14d0 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -7,6 +7,7 @@ from typing import Any, Optional from ...types.content import Message +from ...types.streaming import StopReason from ...types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -121,6 +122,59 @@ def should_reverse_callbacks(self) -> bool: return True +@dataclass +class BeforeModelInvocationEvent(HookEvent): + """Event triggered before the model is invoked. + + This event is fired just before the agent calls the model for inference, + allowing hook providers to inspect or modify the messages and configuration + that will be sent to the model. + + Note: This event is not fired for invocations to structured_output. + """ + + pass + + +@dataclass +class AfterModelInvocationEvent(HookEvent): + """Event triggered after the model invocation completes. + + This event is fired after the agent has finished calling the model, + regardless of whether the invocation was successful or resulted in an error. + Hook providers can use this event for cleanup, logging, or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Note: This event is not fired for invocations to structured_output. + + Attributes: + stop_response: The model response data if invocation was successful, None if failed. + exception: Exception if the model invocation failed, None if successful. + """ + + @dataclass + class ModelStopResponse: + """Model response data from successful invocation. + + Attributes: + stop_reason: The reason the model stopped generating. + message: The generated message from the model. + """ + + message: Message + stop_reason: StopReason + + stop_response: Optional[ModelStopResponse] = None + exception: Optional[Exception] = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + @dataclass class MessageAddedEvent(HookEvent): """Event triggered when a message is added to the agent's conversation. diff --git a/src/strands/experimental/hooks/rules.md b/src/strands/experimental/hooks/rules.md new file mode 100644 index 000000000..a55a71fa3 --- /dev/null +++ b/src/strands/experimental/hooks/rules.md @@ -0,0 +1,20 @@ +# Hook System Rules + +## Terminology + +- **Paired events**: Events that denote the beginning and end of an operation +- **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response + +## Naming Conventions + +- All hook events have a suffix of `Event` +- Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event` + +## Paired Events + +- The final event in a pair returns `True` for `should_reverse_callbacks` +- For every `Before` event there is a corresponding `After` event, even if an exception occurs + +## Writable Properties + +For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable - after invoking the callback for `BeforeToolInvocationEvent`, the `selected_tool` takes effect for the tool call. \ No newline at end of file diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 8eb6a75b6..e7c74dfb9 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -6,8 +6,10 @@ import strands from strands import Agent from strands.experimental.hooks import ( + AfterModelInvocationEvent, AfterToolInvocationEvent, AgentInitializedEvent, + BeforeModelInvocationEvent, BeforeToolInvocationEvent, EndRequestEvent, MessageAddedEvent, @@ -29,6 +31,8 @@ def hook_provider(): EndRequestEvent, AfterToolInvocationEvent, BeforeToolInvocationEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, MessageAddedEvent, ] ) @@ -84,6 +88,11 @@ def assert_message_is_last_message_added(event: MessageAddedEvent): return agent +@pytest.fixture +def tools_config(agent): + return agent.tool_config["tools"] + + @pytest.fixture def user(): class User(BaseModel): @@ -131,20 +140,33 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert len(agent.messages) == 4 -def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): +def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use): """Verify that the correct hook events are emitted as part of __call__.""" agent("test message") length, events = hook_provider.get_events() - assert length == 8 + assert length == 12 assert next(events) == StartRequestEvent(agent=agent) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + message={ + "content": [{"toolUse": tool_use}], + "role": "assistant", + }, + stop_reason="tool_use", + ), + exception=None, + ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolInvocationEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY @@ -157,14 +179,24 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use): result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + message=mock_model.agent_responses[1], + stop_reason="end_turn", + ), + exception=None, + ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) + assert next(events) == EndRequestEvent(agent=agent) assert len(agent.messages) == 4 @pytest.mark.asyncio -async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use): +async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator): """Verify that the correct hook events are emitted as part of stream_async.""" iterator = agent.stream_async("test message") await anext(iterator) @@ -176,13 +208,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u length, events = hook_provider.get_events() - assert length == 8 + assert length == 12 assert next(events) == StartRequestEvent(agent=agent) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + message={ + "content": [{"toolUse": tool_use}], + "role": "assistant", + }, + stop_reason="tool_use", + ), + exception=None, + ) + assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolInvocationEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY @@ -195,7 +240,17 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + message=mock_model.agent_responses[1], + stop_reason="end_turn", + ), + exception=None, + ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) + assert next(events) == EndRequestEvent(agent=agent) assert len(agent.messages) == 4 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0d35fe28b..1c9c4f657 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -7,7 +7,14 @@ import strands import strands.telemetry from strands.event_loop.event_loop import run_tool -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry +from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, + HookProvider, + HookRegistry, +) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException @@ -104,7 +111,14 @@ def hook_registry(): @pytest.fixture def hook_provider(hook_registry): - provider = MockHookProvider(event_types=[BeforeToolInvocationEvent, AfterToolInvocationEvent]) + provider = MockHookProvider( + event_types=[ + BeforeToolInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, + ] + ) hook_registry.add_hook(provider) return provider @@ -390,26 +404,6 @@ async def test_event_loop_cycle_tool_result_no_tool_handler( await alist(stream) -@pytest.mark.asyncio -async def test_event_loop_cycle_tool_result_no_tool_config( - agent, - model, - tool_stream, - agenerator, - alist, -): - model.converse.side_effect = [agenerator(tool_stream)] - # Set tool_config to None for this test - agent.tool_config = None - - with pytest.raises(EventLoopException): - stream = strands.event_loop.event_loop.event_loop_cycle( - agent=agent, - kwargs={}, - ) - await alist(stream) - - @pytest.mark.asyncio async def test_event_loop_cycle_stop( agent, @@ -1008,3 +1002,53 @@ def after_tool_call(self, event: AfterToolInvocationEvent): "test", ), ] + + +@pytest.mark.asyncio +async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider): + """Test that model hooks are correctly emitted even when throttled.""" + # Set up the model to raise throttling exceptions multiple times before succeeding + exception = ModelThrottledException("ThrottlingException | ConverseStream") + model.converse.side_effect = [ + exception, + exception, + exception, + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), + ] + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + kwargs={}, + ) + await alist(stream) + + count, events = hook_provider.get_events() + + assert count == 8 + + # 1st call - throttled + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + + # 2nd call - throttled + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + + # 3rd call - throttled + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + + # 4th call - successful + assert next(events) == BeforeModelInvocationEvent(agent=agent) + assert next(events) == AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" + ), + exception=None, + )