|  | 
| 2 | 2 | 
 | 
| 3 | 3 | import pytest | 
| 4 | 4 | 
 | 
|  | 5 | +from strands import Agent, tool | 
| 5 | 6 | from strands.agent.state import AgentState | 
|  | 7 | +from strands.types.content import Messages | 
|  | 8 | + | 
|  | 9 | +from ...fixtures.mocked_model_provider import MockedModelProvider | 
| 6 | 10 | 
 | 
| 7 | 11 | 
 | 
| 8 | 12 | def test_set_and_get(): | 
| @@ -109,3 +113,33 @@ def test_initial_state(): | 
| 109 | 113 |     assert state.get("key1") == "value1" | 
| 110 | 114 |     assert state.get("key2") == "value2" | 
| 111 | 115 |     assert state.get() == initial | 
|  | 116 | + | 
|  | 117 | + | 
|  | 118 | +def test_agent_state_update_from_tool(): | 
|  | 119 | +    @tool | 
|  | 120 | +    def update_state(agent: Agent): | 
|  | 121 | +        agent.state.set("hello", "world") | 
|  | 122 | +        agent.state.set("foo", "baz") | 
|  | 123 | + | 
|  | 124 | +    agent_messages: Messages = [ | 
|  | 125 | +        { | 
|  | 126 | +            "role": "assistant", | 
|  | 127 | +            "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], | 
|  | 128 | +        }, | 
|  | 129 | +        {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, | 
|  | 130 | +    ] | 
|  | 131 | +    mocked_model_provider = MockedModelProvider(agent_messages) | 
|  | 132 | + | 
|  | 133 | +    agent = Agent( | 
|  | 134 | +        model=mocked_model_provider, | 
|  | 135 | +        tools=[update_state], | 
|  | 136 | +        state={"foo": "bar"}, | 
|  | 137 | +    ) | 
|  | 138 | + | 
|  | 139 | +    assert agent.state.get("hello") is None | 
|  | 140 | +    assert agent.state.get("foo") == "bar" | 
|  | 141 | + | 
|  | 142 | +    agent("Invoke Mocked!") | 
|  | 143 | + | 
|  | 144 | +    assert agent.state.get("hello") == "world" | 
|  | 145 | +    assert agent.state.get("foo") == "baz" | 
0 commit comments