Skip to content

Commit c412292

Browse files
authored
Now callers can subscribe and modify responses before they're sent to the model for processing. (#387)
We also don't support structured output as the parameters are not applicable
1 parent 513f32b commit c412292

File tree

6 files changed

+251
-53
lines changed

6 files changed

+251
-53
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313
import uuid
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
1515

16-
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
17-
from ..experimental.hooks.events import MessageAddedEvent
18-
from ..experimental.hooks.registry import get_registry
16+
from ..experimental.hooks import (
17+
AfterModelInvocationEvent,
18+
AfterToolInvocationEvent,
19+
BeforeModelInvocationEvent,
20+
BeforeToolInvocationEvent,
21+
MessageAddedEvent,
22+
get_registry,
23+
)
1924
from ..telemetry.metrics import Trace
2025
from ..telemetry.tracer import get_tracer
2126
from ..tools.executor import run_tools, validate_and_prepare_tools
@@ -115,6 +120,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
115120

116121
tool_specs = agent.tool_registry.get_all_tool_specs()
117122

123+
get_registry(agent).invoke_callbacks(
124+
BeforeModelInvocationEvent(
125+
agent=agent,
126+
)
127+
)
128+
118129
try:
119130
# TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding
120131
# to the callback handler. This will be revisited when migrating to strongly typed events.
@@ -125,40 +136,50 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
125136
stop_reason, message, usage, metrics = event["stop"]
126137
kwargs.setdefault("request_state", {})
127138

139+
get_registry(agent).invoke_callbacks(
140+
AfterModelInvocationEvent(
141+
agent=agent,
142+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
143+
stop_reason=stop_reason,
144+
message=message,
145+
),
146+
)
147+
)
148+
128149
if model_invoke_span:
129150
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
130151
break # Success! Break out of retry loop
131152

132-
except ContextWindowOverflowException as e:
133-
if model_invoke_span:
134-
tracer.end_span_with_error(model_invoke_span, str(e), e)
135-
raise e
136-
137-
except ModelThrottledException as e:
153+
except Exception as e:
138154
if model_invoke_span:
139155
tracer.end_span_with_error(model_invoke_span, str(e), e)
140156

141-
if attempt + 1 == MAX_ATTEMPTS:
142-
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
143-
raise e
144-
145-
logger.debug(
146-
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
147-
"| throttling exception encountered "
148-
"| delaying before next retry",
149-
current_delay,
150-
MAX_ATTEMPTS,
151-
attempt + 1,
157+
get_registry(agent).invoke_callbacks(
158+
AfterModelInvocationEvent(
159+
agent=agent,
160+
exception=e,
161+
)
152162
)
153-
time.sleep(current_delay)
154-
current_delay = min(current_delay * 2, MAX_DELAY)
155163

156-
yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}}
164+
if isinstance(e, ModelThrottledException):
165+
if attempt + 1 == MAX_ATTEMPTS:
166+
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
167+
raise e
157168

158-
except Exception as e:
159-
if model_invoke_span:
160-
tracer.end_span_with_error(model_invoke_span, str(e), e)
161-
raise e
169+
logger.debug(
170+
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
171+
"| throttling exception encountered "
172+
"| delaying before next retry",
173+
current_delay,
174+
MAX_ATTEMPTS,
175+
attempt + 1,
176+
)
177+
time.sleep(current_delay)
178+
current_delay = min(current_delay * 2, MAX_DELAY)
179+
180+
yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}}
181+
else:
182+
raise e
162183

163184
try:
164185
# Add message in trace and mark the end of the stream messages trace

src/strands/experimental/hooks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def log_end(self, event: EndRequestEvent) -> None:
3030
"""
3131

3232
from .events import (
33+
AfterModelInvocationEvent,
3334
AfterToolInvocationEvent,
3435
AgentInitializedEvent,
36+
BeforeModelInvocationEvent,
3537
BeforeToolInvocationEvent,
3638
EndRequestEvent,
3739
MessageAddedEvent,
@@ -43,6 +45,8 @@ def log_end(self, event: EndRequestEvent) -> None:
4345
"AgentInitializedEvent",
4446
"StartRequestEvent",
4547
"EndRequestEvent",
48+
"BeforeModelInvocationEvent",
49+
"AfterModelInvocationEvent",
4650
"BeforeToolInvocationEvent",
4751
"AfterToolInvocationEvent",
4852
"MessageAddedEvent",

src/strands/experimental/hooks/events.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Optional
88

99
from ...types.content import Message
10+
from ...types.streaming import StopReason
1011
from ...types.tools import AgentTool, ToolResult, ToolUse
1112
from .registry import HookEvent
1213

@@ -121,6 +122,59 @@ def should_reverse_callbacks(self) -> bool:
121122
return True
122123

123124

125+
@dataclass
126+
class BeforeModelInvocationEvent(HookEvent):
127+
"""Event triggered before the model is invoked.
128+
129+
This event is fired just before the agent calls the model for inference,
130+
allowing hook providers to inspect or modify the messages and configuration
131+
that will be sent to the model.
132+
133+
Note: This event is not fired for invocations to structured_output.
134+
"""
135+
136+
pass
137+
138+
139+
@dataclass
140+
class AfterModelInvocationEvent(HookEvent):
141+
"""Event triggered after the model invocation completes.
142+
143+
This event is fired after the agent has finished calling the model,
144+
regardless of whether the invocation was successful or resulted in an error.
145+
Hook providers can use this event for cleanup, logging, or post-processing.
146+
147+
Note: This event uses reverse callback ordering, meaning callbacks registered
148+
later will be invoked first during cleanup.
149+
150+
Note: This event is not fired for invocations to structured_output.
151+
152+
Attributes:
153+
stop_response: The model response data if invocation was successful, None if failed.
154+
exception: Exception if the model invocation failed, None if successful.
155+
"""
156+
157+
@dataclass
158+
class ModelStopResponse:
159+
"""Model response data from successful invocation.
160+
161+
Attributes:
162+
stop_reason: The reason the model stopped generating.
163+
message: The generated message from the model.
164+
"""
165+
166+
message: Message
167+
stop_reason: StopReason
168+
169+
stop_response: Optional[ModelStopResponse] = None
170+
exception: Optional[Exception] = None
171+
172+
@property
173+
def should_reverse_callbacks(self) -> bool:
174+
"""True to invoke callbacks in reverse order."""
175+
return True
176+
177+
124178
@dataclass
125179
class MessageAddedEvent(HookEvent):
126180
"""Event triggered when a message is added to the agent's conversation.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Hook System Rules
2+
3+
## Terminology
4+
5+
- **Paired events**: Events that denote the beginning and end of an operation
6+
- **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response
7+
8+
## Naming Conventions
9+
10+
- All hook events have a suffix of `Event`
11+
- Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event`
12+
13+
## Paired Events
14+
15+
- The final event in a pair returns `True` for `should_reverse_callbacks`
16+
- For every `Before` event there is a corresponding `After` event, even if an exception occurs
17+
18+
## Writable Properties
19+
20+
For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable - after invoking the callback for `BeforeToolInvocationEvent`, the `selected_tool` takes effect for the tool call.

tests/strands/agent/test_agent_hooks.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import strands
77
from strands import Agent
88
from strands.experimental.hooks import (
9+
AfterModelInvocationEvent,
910
AfterToolInvocationEvent,
1011
AgentInitializedEvent,
12+
BeforeModelInvocationEvent,
1113
BeforeToolInvocationEvent,
1214
EndRequestEvent,
1315
MessageAddedEvent,
@@ -29,6 +31,8 @@ def hook_provider():
2931
EndRequestEvent,
3032
AfterToolInvocationEvent,
3133
BeforeToolInvocationEvent,
34+
BeforeModelInvocationEvent,
35+
AfterModelInvocationEvent,
3236
MessageAddedEvent,
3337
]
3438
)
@@ -84,6 +88,11 @@ def assert_message_is_last_message_added(event: MessageAddedEvent):
8488
return agent
8589

8690

91+
@pytest.fixture
92+
def tools_config(agent):
93+
return agent.tool_config["tools"]
94+
95+
8796
@pytest.fixture
8897
def user():
8998
class User(BaseModel):
@@ -131,20 +140,33 @@ def test_agent_tool_call(agent, hook_provider, agent_tool):
131140
assert len(agent.messages) == 4
132141

133142

134-
def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
143+
def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use):
135144
"""Verify that the correct hook events are emitted as part of __call__."""
136145

137146
agent("test message")
138147

139148
length, events = hook_provider.get_events()
140149

141-
assert length == 8
150+
assert length == 12
142151

143152
assert next(events) == StartRequestEvent(agent=agent)
144153
assert next(events) == MessageAddedEvent(
145154
agent=agent,
146155
message=agent.messages[0],
147156
)
157+
assert next(events) == BeforeModelInvocationEvent(agent=agent)
158+
assert next(events) == AfterModelInvocationEvent(
159+
agent=agent,
160+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
161+
message={
162+
"content": [{"toolUse": tool_use}],
163+
"role": "assistant",
164+
},
165+
stop_reason="tool_use",
166+
),
167+
exception=None,
168+
)
169+
148170
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
149171
assert next(events) == BeforeToolInvocationEvent(
150172
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
@@ -157,14 +179,24 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
157179
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
158180
)
159181
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
182+
assert next(events) == BeforeModelInvocationEvent(agent=agent)
183+
assert next(events) == AfterModelInvocationEvent(
184+
agent=agent,
185+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
186+
message=mock_model.agent_responses[1],
187+
stop_reason="end_turn",
188+
),
189+
exception=None,
190+
)
160191
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
192+
161193
assert next(events) == EndRequestEvent(agent=agent)
162194

163195
assert len(agent.messages) == 4
164196

165197

166198
@pytest.mark.asyncio
167-
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use):
199+
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator):
168200
"""Verify that the correct hook events are emitted as part of stream_async."""
169201
iterator = agent.stream_async("test message")
170202
await anext(iterator)
@@ -176,13 +208,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
176208

177209
length, events = hook_provider.get_events()
178210

179-
assert length == 8
211+
assert length == 12
180212

181213
assert next(events) == StartRequestEvent(agent=agent)
182214
assert next(events) == MessageAddedEvent(
183215
agent=agent,
184216
message=agent.messages[0],
185217
)
218+
assert next(events) == BeforeModelInvocationEvent(agent=agent)
219+
assert next(events) == AfterModelInvocationEvent(
220+
agent=agent,
221+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
222+
message={
223+
"content": [{"toolUse": tool_use}],
224+
"role": "assistant",
225+
},
226+
stop_reason="tool_use",
227+
),
228+
exception=None,
229+
)
230+
186231
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
187232
assert next(events) == BeforeToolInvocationEvent(
188233
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
@@ -195,7 +240,17 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
195240
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
196241
)
197242
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
243+
assert next(events) == BeforeModelInvocationEvent(agent=agent)
244+
assert next(events) == AfterModelInvocationEvent(
245+
agent=agent,
246+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
247+
message=mock_model.agent_responses[1],
248+
stop_reason="end_turn",
249+
),
250+
exception=None,
251+
)
198252
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
253+
199254
assert next(events) == EndRequestEvent(agent=agent)
200255

201256
assert len(agent.messages) == 4

0 commit comments

Comments
 (0)