66import strands
77from strands import Agent
88from strands .experimental .hooks import (
9+ AfterModelInvocationEvent ,
910 AfterToolInvocationEvent ,
1011 AgentInitializedEvent ,
12+ BeforeModelInvocationEvent ,
1113 BeforeToolInvocationEvent ,
1214 EndRequestEvent ,
1315 MessageAddedEvent ,
@@ -29,6 +31,8 @@ def hook_provider():
2931 EndRequestEvent ,
3032 AfterToolInvocationEvent ,
3133 BeforeToolInvocationEvent ,
34+ BeforeModelInvocationEvent ,
35+ AfterModelInvocationEvent ,
3236 MessageAddedEvent ,
3337 ]
3438 )
@@ -84,6 +88,11 @@ def assert_message_is_last_message_added(event: MessageAddedEvent):
8488 return agent
8589
8690
91+ @pytest .fixture
92+ def tools_config (agent ):
93+ return agent .tool_config ["tools" ]
94+
95+
8796@pytest .fixture
8897def user ():
8998 class User (BaseModel ):
@@ -131,20 +140,33 @@ def test_agent_tool_call(agent, hook_provider, agent_tool):
131140 assert len (agent .messages ) == 4
132141
133142
134- def test_agent__call__hooks (agent , hook_provider , agent_tool , tool_use ):
143+ def test_agent__call__hooks (agent , hook_provider , agent_tool , mock_model , tool_use ):
135144 """Verify that the correct hook events are emitted as part of __call__."""
136145
137146 agent ("test message" )
138147
139148 length , events = hook_provider .get_events ()
140149
141- assert length == 8
150+ assert length == 12
142151
143152 assert next (events ) == StartRequestEvent (agent = agent )
144153 assert next (events ) == MessageAddedEvent (
145154 agent = agent ,
146155 message = agent .messages [0 ],
147156 )
157+ assert next (events ) == BeforeModelInvocationEvent (agent = agent )
158+ assert next (events ) == AfterModelInvocationEvent (
159+ agent = agent ,
160+ stop_response = AfterModelInvocationEvent .ModelStopResponse (
161+ message = {
162+ "content" : [{"toolUse" : tool_use }],
163+ "role" : "assistant" ,
164+ },
165+ stop_reason = "tool_use" ,
166+ ),
167+ exception = None ,
168+ )
169+
148170 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
149171 assert next (events ) == BeforeToolInvocationEvent (
150172 agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
@@ -157,14 +179,24 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
157179 result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
158180 )
159181 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
182+ assert next (events ) == BeforeModelInvocationEvent (agent = agent )
183+ assert next (events ) == AfterModelInvocationEvent (
184+ agent = agent ,
185+ stop_response = AfterModelInvocationEvent .ModelStopResponse (
186+ message = mock_model .agent_responses [1 ],
187+ stop_reason = "end_turn" ,
188+ ),
189+ exception = None ,
190+ )
160191 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
192+
161193 assert next (events ) == EndRequestEvent (agent = agent )
162194
163195 assert len (agent .messages ) == 4
164196
165197
166198@pytest .mark .asyncio
167- async def test_agent_stream_async_hooks (agent , hook_provider , agent_tool , tool_use ):
199+ async def test_agent_stream_async_hooks (agent , hook_provider , agent_tool , mock_model , tool_use , agenerator ):
168200 """Verify that the correct hook events are emitted as part of stream_async."""
169201 iterator = agent .stream_async ("test message" )
170202 await anext (iterator )
@@ -176,13 +208,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
176208
177209 length , events = hook_provider .get_events ()
178210
179- assert length == 8
211+ assert length == 12
180212
181213 assert next (events ) == StartRequestEvent (agent = agent )
182214 assert next (events ) == MessageAddedEvent (
183215 agent = agent ,
184216 message = agent .messages [0 ],
185217 )
218+ assert next (events ) == BeforeModelInvocationEvent (agent = agent )
219+ assert next (events ) == AfterModelInvocationEvent (
220+ agent = agent ,
221+ stop_response = AfterModelInvocationEvent .ModelStopResponse (
222+ message = {
223+ "content" : [{"toolUse" : tool_use }],
224+ "role" : "assistant" ,
225+ },
226+ stop_reason = "tool_use" ,
227+ ),
228+ exception = None ,
229+ )
230+
186231 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [1 ])
187232 assert next (events ) == BeforeToolInvocationEvent (
188233 agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
@@ -195,7 +240,17 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
195240 result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
196241 )
197242 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [2 ])
243+ assert next (events ) == BeforeModelInvocationEvent (agent = agent )
244+ assert next (events ) == AfterModelInvocationEvent (
245+ agent = agent ,
246+ stop_response = AfterModelInvocationEvent .ModelStopResponse (
247+ message = mock_model .agent_responses [1 ],
248+ stop_reason = "end_turn" ,
249+ ),
250+ exception = None ,
251+ )
198252 assert next (events ) == MessageAddedEvent (agent = agent , message = agent .messages [3 ])
253+
199254 assert next (events ) == EndRequestEvent (agent = agent )
200255
201256 assert len (agent .messages ) == 4
0 commit comments