From c5e4e51a0392fb921a280a8891de40398927fe98 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 30 Jul 2025 16:31:46 -0400 Subject: [PATCH 01/23] fix(event_loop): raise dedicated exception when encountering max tokens stop reason --- src/strands/event_loop/event_loop.py | 15 ++++++- src/strands/types/exceptions.py | 11 +++++ tests/strands/event_loop/test_event_loop.py | 48 ++++++++++++++++++++- tests_integ/test_max_tokens_reached.py | 18 ++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 tests_integ/test_max_tokens_reached.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ffcb6a5c9..5b96dfc92 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,7 +28,12 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + EventLoopMaxTokensReachedException, + ModelThrottledException, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .streaming import stream_messages @@ -216,6 +221,14 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> yield event return + elif stop_reason == "max_tokens": + raise EventLoopMaxTokensReachedException( + ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + ) # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 4bd3fd88e..14f76e945 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -18,6 +18,17 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) +class EventLoopMaxTokensReachedException(EventLoopException): + """Exception raised when the model reaches its maximum token generation limit. + + This exception is raised when the model stops generating tokens because it has reached the maximum number of + tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for + the complexity of the response, or when the model naturally reaches its configured output limit during generation. + """ + + pass + + class ContextWindowOverflowException(Exception): """Exception raised when the context window is exceeded. 
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3303b7282 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + EventLoopMaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,47 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(EventLoopMaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..b6f6b2857 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,18 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import EventLoopMaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_context_window_overflow(): + model = BedrockModel(max_tokens=1) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(EventLoopMaxTokensReachedException): + agent("Tell me a story!") From 6703819d6b6cdedb7b08d92e028bb3deca6c4e78 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 30 Jul 2025 17:02:03 -0400 Subject: [PATCH 02/23] fix: update integ tests --- src/strands/event_loop/event_loop.py | 2 +- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/types/exceptions.py | 2 +- tests/strands/event_loop/test_event_loop.py | 9 ++++----- tests_integ/test_max_tokens_reached.py | 7 ++++--- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 5b96dfc92..16fefa5ac 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -226,7 +226,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ( "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" ) ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 14f76e945..7d9f1c6dc 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -18,7 +18,7 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) -class EventLoopMaxTokensReachedException(EventLoopException): +class EventLoopMaxTokensReachedException(Exception): """Exception raised when the model reaches its maximum token generation limit. 
This exception is raised when the model stops generating tokens because it has reached the maximum number of diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 3303b7282..05b20ba01 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -22,7 +22,6 @@ from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, - EventLoopMaxTokensReachedException, ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -568,7 +567,7 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises MaxTokensReachedException.""" + """Test that max_tokens stop reason raises EventLoopMaxTokensReachedException.""" # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 model.stream.return_value = agenerator( @@ -585,8 +584,8 @@ async def test_event_loop_cycle_max_tokens_exception( ] ) - # Call event_loop_cycle, expecting it to raise MaxTokensReachedException - with pytest.raises(EventLoopMaxTokensReachedException) as exc_info: + # Call event_loop_cycle, expecting it to raise EventLoopMaxTokensReachedException + with pytest.raises(EventLoopException) as exc_info: stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -597,7 +596,7 @@ async def test_event_loop_cycle_max_tokens_exception( expected_message = ( "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" ) assert str(exc_info.value) == expected_message diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index b6f6b2857..1bf75f136 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,8 +1,7 @@ -import pytest from strands import Agent, tool from strands.models.bedrock import BedrockModel -from strands.types.exceptions import EventLoopMaxTokensReachedException +from strands.types.exceptions import EventLoopException, EventLoopMaxTokensReachedException @tool @@ -14,5 +13,7 @@ def test_context_window_overflow(): model = BedrockModel(max_tokens=1) agent = Agent(model=model, tools=[story_tool]) - with pytest.raises(EventLoopMaxTokensReachedException): + try: agent("Tell me a story!") + except EventLoopException as e: + assert isinstance(e.original_exception, EventLoopMaxTokensReachedException) From c94b74e75236dcbac0ffdb438f3a4a9ff59cda5f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 10:50:40 -0400 Subject: [PATCH 03/23] fix: rename exception message, add to exception, move earlier in cycle --- src/strands/event_loop/event_loop.py | 29 ++++++++++++++------- src/strands/types/exceptions.py | 14 ++++++++-- tests/strands/event_loop/test_event_loop.py | 13 ++++++--- tests_integ/test_max_tokens_reached.py | 7 +++-- 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 16fefa5ac..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -31,7 +31,7 @@ from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, - EventLoopMaxTokensReachedException, + 
MaxTokensReachedException, ModelThrottledException, ) from ..types.streaming import Metrics, StopReason @@ -192,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -221,14 +237,6 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> yield event return - elif stop_reason == "max_tokens": - raise EventLoopMaxTokensReachedException( - ( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" - ) - ) # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) @@ -244,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. 
raise - except ContextWindowOverflowException as e: + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) raise e diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 7d9f1c6dc..71ea28b9f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from strands.types.content import Message + class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -18,7 +20,7 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) -class EventLoopMaxTokensReachedException(Exception): +class MaxTokensReachedException(Exception): """Exception raised when the model reaches its maximum token generation limit. This exception is raised when the model stops generating tokens because it has reached the maximum number of @@ -26,7 +28,15 @@ class EventLoopMaxTokensReachedException(Exception): the complexity of the response, or when the model naturally reaches its configured output limit during generation. """ - pass + def __init__(self, message: str, incomplete_message: Message): + """Initialize the exception with an error message and the incomplete message object. 
+ + Args: + message: The error message describing the token limit issue + incomplete_message: The valid Message object with incomplete content due to token limits + """ + self.incomplete_message = incomplete_message + super().__init__(message) class ContextWindowOverflowException(Exception): diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 05b20ba01..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -22,6 +22,7 @@ from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, + MaxTokensReachedException, ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -567,7 +568,7 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises EventLoopMaxTokensReachedException.""" + """Test that max_tokens stop reason raises MaxTokensReachedException.""" # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 model.stream.return_value = agenerator( @@ -584,8 +585,8 @@ async def test_event_loop_cycle_max_tokens_exception( ] ) - # Call event_loop_cycle, expecting it to raise EventLoopMaxTokensReachedException - with pytest.raises(EventLoopException) as exc_info: + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -596,10 +597,14 @@ async def test_event_loop_cycle_max_tokens_exception( expected_message = ( "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" ) assert str(exc_info.value) == expected_message + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 1bf75f136..519cf62c2 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,7 +1,8 @@ +import pytest from strands import Agent, tool from strands.models.bedrock import BedrockModel -from strands.types.exceptions import EventLoopException, EventLoopMaxTokensReachedException +from strands.types.exceptions import MaxTokensReachedException @tool @@ -13,7 +14,5 @@ def test_context_window_overflow(): model = BedrockModel(max_tokens=1) agent = Agent(model=model, tools=[story_tool]) - try: + with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - except EventLoopException as e: - assert isinstance(e.original_exception, EventLoopMaxTokensReachedException) From 36dd0f9304ba0daa4fceffef614ff91400fcb23a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:53:04 -0400 Subject: [PATCH 04/23] Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg --- tests_integ/test_max_tokens_reached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 519cf62c2..1b822dcba 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -11,7 +11,7 @@ def story_tool(story: str) -> str: def test_context_window_overflow(): - model = 
BedrockModel(max_tokens=1) + model = BedrockModel(max_tokens=100) agent = Agent(model=model, tools=[story_tool]) with pytest.raises(MaxTokensReachedException): From e04c73d85d86dde5d9e415ae2ef693aa9a55da56 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:53:11 -0400 Subject: [PATCH 05/23] Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg --- tests_integ/test_max_tokens_reached.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 1b822dcba..5f7e5584c 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -16,3 +16,5 @@ def test_context_window_overflow(): with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") + + assert len(agent.messages) == 1 From cca2f86a3f7a1d22cfa8cf59ffa0029943a0efa7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:57:19 -0400 Subject: [PATCH 06/23] linting --- tests_integ/test_max_tokens_reached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 5f7e5584c..d9c2817b3 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -16,5 +16,5 @@ def test_context_window_overflow(): with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - + assert len(agent.messages) == 1 From 2e2d4df9f6d7d98993f65fb40540663c74f7f0ea Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 4 Aug 2025 17:47:25 -0400 Subject: [PATCH 07/23] feat: add builtin hook provider to address max tokens reached truncation --- src/strands/agent/agent.py | 18 +++- src/strands/experimental/hooks/__init__.py | 2 + src/strands/experimental/hooks/events.py | 26 +++++ .../experimental/hooks/providers/__init__.py | 0 .../correct_tool_use_hook_provider.py | 95 +++++++++++++++++++ 
tests/strands/agent/test_agent_hooks.py | 55 ++++++++++- tests_integ/test_max_tokens_reached.py | 18 ++++ 7 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 src/strands/experimental/hooks/providers/__init__.py create mode 100644 src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 111509e3a..c86b64ff3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool +from ..experimental.hooks.events import EventLoopFailureEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -582,7 +583,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A ) async for event in events: yield event - + return except ContextWindowOverflowException as e: # Try reducing the context size and retrying self.conversation_manager.reduce_context(self, e=e) @@ -591,6 +592,21 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A if self._session_manager: self._session_manager.sync_agent(self) + # If the events have been handled, attempt to restart the event loop in the now-healthy state + events = self._execute_event_loop_cycle(invocation_state) + async for event in events: + yield event + except Exception as e: + """ + Catch all other exceptions which are unrecoverable without intervention. 
+ Reraise exception if EventLoopFailureEvent.should_continue is false + """ + event_loop_failure_event = EventLoopFailureEvent(agent=self, exception=e) + self.hooks.invoke_callbacks(event_loop_failure_event) + if not event_loop_failure_event.should_continue_loop: + raise + + # If the events have been handled, attempt to restart the event loop in the now-healthy state events = self._execute_event_loop_cycle(invocation_state) async for event in events: yield event diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 098d4cf0d..384d8a505 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -5,6 +5,7 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, + EventLoopFailureEvent, ) __all__ = [ @@ -12,4 +13,5 @@ "AfterToolInvocationEvent", "BeforeModelInvocationEvent", "AfterModelInvocationEvent", + "EventLoopFailureEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d03e65d85..128882821 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -121,3 +121,29 @@ class ModelStopResponse: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +@dataclass +class EventLoopFailureEvent(HookEvent): + """Event triggered when the event loop encounters a failure. + + This event is fired when an exception occurs during event loop execution, + allowing hook providers to handle the failure or perform recovery actions. + + Attributes: + exception: The exception that caused the event loop failure. + should_continue_loop: Flag that hooks can set to True to indicate they have + handled the exception and the event loop should continue normally. 
+ + Warning: + Setting should_continue_loop=True without properly addressing the underlying + cause of the exception may result in infinite loops if the same failure + condition persists. Hooks should implement appropriate error handling, + retry limits, or state modifications to prevent recurring failures. + """ + + exception: Exception + should_continue_loop: bool = False + + def _can_write(self, name: str) -> bool: + return name == "should_continue_loop" diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py new file mode 100644 index 000000000..96020cc58 --- /dev/null +++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py @@ -0,0 +1,95 @@ +import logging +from typing import Any + +from src.strands.hooks import MessageAddedEvent +from src.strands.types.tools import ToolUse +from strands.experimental.hooks.events import EventLoopFailureEvent +from strands.hooks import HookProvider, HookRegistry +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import MaxTokensReachedException + +logger = logging.getLogger(__name__) + + +class CorrectToolUseHookProvider(HookProvider): + """Hook provider that handles MaxTokensReachedException by fixing incomplete tool uses. + + This hook provider is triggered when a MaxTokensReachedException occurs during event loop execution. + When the model's response is truncated due to token limits, tool use entries may be incomplete + or missing required fields (name, input, toolUseId). + + The provider fixes these issues by: + + 1. Inspecting each content block in the incomplete message for invalid tool uses + 2. Replacing incomplete tool use blocks with informative text messages + 3. 
Preserving valid content blocks in the corrected message + 4. Adding the corrected message to the agent's conversation history + 5. Allowing the event loop to continue processing + + If a tool use is invalid for unknown reasons, not due to empty fields, the hook + allows the original exception to propagate to avoid unsafe recovery attempts. + """ + + def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: + """Register hook to handle EventLoopFailureEvent for MaxTokensReachedException.""" + registry.add_callback(EventLoopFailureEvent, self._handle_max_tokens_reached) + + def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: + """Handle MaxTokensReachedException by cleaning up orphaned tool uses and allowing continuation.""" + if not isinstance(event.exception, MaxTokensReachedException): + return + + logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") + + incomplete_message: Message = event.exception.incomplete_message + valid_content: list[ContentBlock] = [] + + for i, content in enumerate(incomplete_message["content"]): + tool_use: ToolUse = content.get("toolUse") + if not tool_use: + valid_content.append(content) + logger.debug(f"Content block {i}: Valid non-tool content preserved") + continue + + """ + Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented + using pydantic, we inspect each field. + """ + tool_name = tool_use.get("name", "") + tool_input = tool_use.get("input") + tool_use_id = tool_use.get("toolUseId") + + if not (tool_name and tool_input and tool_use_id): + """ + If tool_use does not conform to the expected schema it means the max_tokens issue resulted in it not + being populated it correctly. + + It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure + on the next iteration. 
+ """ + logger.warning( + f"Invalid tool use found at content block {i}: tool_name='{tool_name}', " + f"Replacing with error message due to max_tokens truncation." + ) + + valid_content.append( + { + "text": f"The selected tool {tool_name}'s tool use was incomplete due " + f"to maximum token limits being reached." + } + ) + else: + # Tool use is invalid for an unknown reason. Cannot safely recover, so allow exception to propagate + logger.debug( + f"Tool use at content block {i} appears complete but is still invalid. " + f"tool_name='{tool_name}', tool_use_id='{tool_use_id}'. " + f"Cannot safely recover - allowing exception to propagate." + ) + return + + valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} + event.agent.messages.append(valid_message) + event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message)) + event.should_continue_loop = True + + logger.info("MaxTokensReachedException handled successfully - continuing event loop") diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index cd89fbc7a..e71c0aa94 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,15 +1,17 @@ -from unittest.mock import ANY, Mock +from unittest.mock import ANY, Mock, patch import pytest from pydantic import BaseModel import strands +from src.strands.types.exceptions import MaxTokensReachedException from strands import Agent from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, + EventLoopFailureEvent, ) from strands.hooks import ( AfterInvocationEvent, @@ -35,6 +37,7 @@ def hook_provider(): BeforeModelInvocationEvent, AfterModelInvocationEvent, MessageAddedEvent, + EventLoopFailureEvent, ] ) @@ -292,3 +295,53 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert next(events) == 
AfterInvocationEvent(agent=agent) assert len(agent.messages) == 1 + + +def test_event_loop_failure_event_exception_rethrown_when_not_handled(agent, hook_provider): + """Test that EventLoopFailureEvent is triggered and exceptions are re-thrown when not handled.""" + + # Mock event_loop_cycle to raise a general exception (not ContextWindowOverflowException) + with patch("strands.agent.agent.event_loop_cycle") as mock_cycle: + mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"}) + + with pytest.raises(MaxTokensReachedException): + agent("test message") + length, events = hook_provider.get_events() + failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)] + + assert len(failure_events) == 1 + assert isinstance(failure_events[0].exception, MaxTokensReachedException) + assert failure_events[0].should_continue_loop is False + + +def test_event_loop_failure_event_exception_handled_by_hook(agent, hook_provider): + """Test that EventLoopFailureEvent allows hooks to handle exceptions and continue execution.""" + + first_call = True + + def hook_callback(event: EventLoopFailureEvent): + nonlocal first_call + # Hook handles the exception by setting should_continue_loop to True + event.should_continue_loop = first_call + first_call = False + + agent.hooks.add_callback(EventLoopFailureEvent, hook_callback) + + # Mock event_loop_cycle to raise a general exception + with patch("strands.agent.agent.event_loop_cycle") as mock_cycle: + mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"}) + + # Should NOT raise exception due to hook handling on the first failure + with pytest.raises(MaxTokensReachedException): + agent("test message") + + length, events = hook_provider.get_events() + failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)] + + assert len(failure_events) == 2 + + assert 
isinstance(failure_events[0].exception, MaxTokensReachedException) + assert failure_events[0].should_continue_loop is True + + assert isinstance(failure_events[1].exception, MaxTokensReachedException) + assert failure_events[1].should_continue_loop is False diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index d9c2817b3..7c7a48973 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,12 +1,20 @@ +import logging + import pytest from strands import Agent, tool +from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException +logger = logging.getLogger(__name__) + @tool def story_tool(story: str) -> str: + """ + Tool that writes a story that is minimum 50,000 lines long. + """ return story @@ -18,3 +26,13 @@ def test_context_window_overflow(): agent("Tell me a story!") assert len(agent.messages) == 1 + + +def test_max_tokens_reached_with_hook_provider(): + """Test that MaxTokensReachedException can be handled by a hook provider.""" + model = BedrockModel(max_tokens=100) + hook_provider = CorrectToolUseHookProvider() + agent = Agent(model=model, tools=[story_tool], hooks=[hook_provider]) + + # This should NOT raise an exception because the hook handles it + agent("Tell me a story!") From 447d147ee001288dcec224d4e3389b71a7f0dd2c Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 4 Aug 2025 18:23:40 -0400 Subject: [PATCH 08/23] tests: modify integ test to inspect message history --- .../experimental/hooks/providers/__init__.py | 5 + .../correct_tool_use_hook_provider.py | 40 ++++---- .../experimental/hooks/providers/__init__.py | 1 + .../test_correct_tool_use_hook_provider.py | 99 +++++++++++++++++++ tests_integ/test_max_tokens_reached.py | 13 +++ 5 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 
tests/strands/experimental/hooks/providers/__init__.py create mode 100644 tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py index e69de29bb..5b74733b8 100644 --- a/src/strands/experimental/hooks/providers/__init__.py +++ b/src/strands/experimental/hooks/providers/__init__.py @@ -0,0 +1,5 @@ +"""Hook providers for experimental Strands Agents functionality. + +This package contains experimental hook providers that extend the core agent functionality +with additional capabilities. +""" diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py index 96020cc58..3c9ef0803 100644 --- a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py +++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py @@ -1,3 +1,11 @@ +"""Hook provider for correcting incomplete tool uses due to token limits. + +This module provides the CorrectToolUseHookProvider class, which handles scenarios where +the model's response is truncated due to maximum token limits, resulting in incomplete +or malformed tool use entries. The provider automatically corrects these issues to allow +the agent conversation to continue gracefully. 
+""" + import logging from typing import Any @@ -42,24 +50,25 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = event.exception.incomplete_message - valid_content: list[ContentBlock] = [] - for i, content in enumerate(incomplete_message["content"]): + if not incomplete_message["content"]: + # Cannot correct invalid content block if content is empty + return + + valid_content: list[ContentBlock] = [] + for content in incomplete_message["content"]: tool_use: ToolUse = content.get("toolUse") if not tool_use: valid_content.append(content) - logger.debug(f"Content block {i}: Valid non-tool content preserved") continue """ Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented using pydantic, we inspect each field. """ - tool_name = tool_use.get("name", "") - tool_input = tool_use.get("input") - tool_use_id = tool_use.get("toolUseId") - - if not (tool_name and tool_input and tool_use_id): + # Check if tool use is incomplete (missing or empty required fields) + tool_name = tool_use.get("name") + if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): """ If tool_use does not conform to the expected schema it means the max_tokens issue resulted in it not being populated it correctly. @@ -67,29 +76,22 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure on the next iteration. """ + display_name = tool_name if tool_name else "" logger.warning( - f"Invalid tool use found at content block {i}: tool_name='{tool_name}', " - f"Replacing with error message due to max_tokens truncation." 
+ "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name ) valid_content.append( { - "text": f"The selected tool {tool_name}'s tool use was incomplete due " + "text": f"The selected tool {display_name}'s tool use was incomplete due " f"to maximum token limits being reached." } ) else: - # Tool use is invalid for an unknown reason. Cannot safely recover, so allow exception to propagate - logger.debug( - f"Tool use at content block {i} appears complete but is still invalid. " - f"tool_name='{tool_name}', tool_use_id='{tool_use_id}'. " - f"Cannot safely recover - allowing exception to propagate." - ) + # ToolUse was invalid for an unknown reason. Cannot correct, return and allow exception to propagate up. return valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} event.agent.messages.append(valid_message) event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message)) event.should_continue_loop = True - - logger.info("MaxTokensReachedException handled successfully - continuing event loop") diff --git a/tests/strands/experimental/hooks/providers/__init__.py b/tests/strands/experimental/hooks/providers/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/strands/experimental/hooks/providers/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py new file mode 100644 index 000000000..93d672ab2 --- /dev/null +++ b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py @@ -0,0 +1,99 @@ +"""Unit tests for CorrectToolUseHookProvider.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks.events import EventLoopFailureEvent +from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider 
+from strands.hooks import HookRegistry +from strands.types.content import Message +from strands.types.exceptions import MaxTokensReachedException + + +@pytest.fixture +def hook_provider(): + """Create a CorrectToolUseHookProvider instance.""" + return CorrectToolUseHookProvider() + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with messages and hooks.""" + agent = Mock() + agent.messages = [] + agent.hooks = Mock() + return agent + + +@pytest.fixture +def mock_registry(): + """Create a mock hook registry.""" + return Mock(spec=HookRegistry) + + +def test_register_hooks(hook_provider, mock_registry): + """Test that the hook provider registers the correct callback.""" + hook_provider.register_hooks(mock_registry) + + mock_registry.add_callback.assert_called_once_with(EventLoopFailureEvent, hook_provider._handle_max_tokens_reached) + + +def test_handle_non_max_tokens_exception(hook_provider, mock_agent): + """Test that non-MaxTokensReachedException events are ignored.""" + other_exception = ValueError("Some other error") + event = EventLoopFailureEvent(agent=mock_agent, exception=other_exception) + + hook_provider._handle_max_tokens_reached(event) + + # Should not modify the agent or event + assert len(mock_agent.messages) == 0 + assert not event.should_continue_loop + mock_agent.hooks.invoke_callbacks.assert_not_called() + + +@pytest.mark.parametrize( + "incomplete_tool_use,expected_tool_name", + [ + ({"toolUseId": "tool-123", "input": {"param": "value"}}, ""), # Missing name + ({"name": "test_tool", "toolUseId": "tool-123"}, "test_tool"), # Missing input + ({"name": "test_tool", "input": {}, "toolUseId": "tool-123"}, "test_tool"), # Empty input + ({"name": "test_tool", "input": {"param": "value"}}, "test_tool"), # Missing toolUseId + ], +) +def test_handle_max_tokens_with_incomplete_tool_use(hook_provider, mock_agent, incomplete_tool_use, expected_tool_name): + """Test handling various incomplete tool use scenarios.""" + incomplete_message: Message 
= { + "role": "user", # Test role preservation + "content": [{"text": "I'll use a tool"}, {"toolUse": incomplete_tool_use}], + } + + exception = MaxTokensReachedException("Max tokens reached", incomplete_message) + event = EventLoopFailureEvent(agent=mock_agent, exception=exception) + + hook_provider._handle_max_tokens_reached(event) + + # Should add corrected message with error text and preserve role + assert len(mock_agent.messages) == 1 + added_message = mock_agent.messages[0] + assert added_message["role"] == "user" # Role preserved + assert len(added_message["content"]) == 2 + assert added_message["content"][0]["text"] == "I'll use a tool" + assert f"The selected tool {expected_tool_name}'s tool use was incomplete" in added_message["content"][1]["text"] + assert "maximum token limits being reached" in added_message["content"][1]["text"] + + assert event.should_continue_loop + + +def test_handle_max_tokens_with_no_content(hook_provider, mock_agent): + """Test handling message with no content blocks.""" + incomplete_message: Message = {"role": "assistant", "content": []} + + exception = MaxTokensReachedException("Max tokens reached", incomplete_message) + event = EventLoopFailureEvent(agent=mock_agent, exception=exception) + + hook_provider._handle_max_tokens_reached(event) + + # Should add empty message and continue + assert len(mock_agent.messages) == 0 + assert not event.should_continue_loop diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 7c7a48973..6bad70636 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -36,3 +36,16 @@ def test_max_tokens_reached_with_hook_provider(): # This should NOT raise an exception because the hook handles it agent("Tell me a story!") + + # Validate that at least one message contains the incomplete tool use error message + expected_text = "tool use was incomplete due to maximum token limits being reached" + all_text_content = [ + 
content_block["text"] + for message in agent.messages + for content_block in message.get("content", []) + if "text" in content_block + ] + + assert any(expected_text in text for text in all_text_content), ( + f"Expected to find message containing '{expected_text}' in agent messages" + ) From 564895d5e04ec6e46ebec04bfa5421fc0fbbcdce Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 4 Aug 2025 18:31:25 -0400 Subject: [PATCH 09/23] fix: fix linting errors --- tests/strands/agent/test_agent_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index e71c0aa94..66eb86808 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -4,7 +4,6 @@ from pydantic import BaseModel import strands -from src.strands.types.exceptions import MaxTokensReachedException from strands import Agent from strands.experimental.hooks import ( AfterModelInvocationEvent, @@ -20,6 +19,7 @@ MessageAddedEvent, ) from strands.types.content import Messages +from strands.types.exceptions import MaxTokensReachedException from strands.types.tools import ToolResult, ToolUse from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider From 2f118fb03b7faba54850981e99c7bf1a76785bca Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 4 Aug 2025 18:39:28 -0400 Subject: [PATCH 10/23] fix: linting --- .../hooks/providers/correct_tool_use_hook_provider.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py index 3c9ef0803..c8b12d98d 100644 --- a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py +++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py @@ -9,12 +9,11 
@@ import logging from typing import Any -from src.strands.hooks import MessageAddedEvent -from src.strands.types.tools import ToolUse from strands.experimental.hooks.events import EventLoopFailureEvent -from strands.hooks import HookProvider, HookRegistry +from strands.hooks import HookProvider, HookRegistry, MessageAddedEvent from strands.types.content import ContentBlock, Message from strands.types.exceptions import MaxTokensReachedException +from strands.types.tools import ToolUse logger = logging.getLogger(__name__) @@ -57,7 +56,7 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: valid_content: list[ContentBlock] = [] for content in incomplete_message["content"]: - tool_use: ToolUse = content.get("toolUse") + tool_use: ToolUse | None = content.get("toolUse") if not tool_use: valid_content.append(content) continue From e5fc51a432bdc89f112b3b8fc55b3c1e7b4d063a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 5 Aug 2025 18:30:08 -0400 Subject: [PATCH 11/23] refactor: switch from hook approach to conversation manager --- src/strands/agent/agent.py | 38 ++-- .../conversation_manager.py | 16 ++ .../null_conversation_manager.py | 15 +- .../sliding_window_conversation_manager.py | 13 +- .../summarizing_conversation_manager.py | 13 +- .../token_limit_recovery.py | 66 ++++++ src/strands/experimental/hooks/__init__.py | 2 - src/strands/experimental/hooks/events.py | 26 --- .../experimental/hooks/providers/__init__.py | 5 - .../correct_tool_use_hook_provider.py | 96 --------- .../agent/conversation_manager/__init__.py | 1 + .../test_token_limit_recovery.py | 200 ++++++++++++++++++ tests/strands/agent/test_agent.py | 68 +++++- tests/strands/agent/test_agent_hooks.py | 55 +---- .../agent/test_conversation_manager.py | 92 +++++++- .../experimental/hooks/providers/__init__.py | 1 - .../test_correct_tool_use_hook_provider.py | 99 --------- tests_integ/test_max_tokens_reached.py | 9 +- 18 files changed, 497 insertions(+), 318 deletions(-) 
create mode 100644 src/strands/agent/conversation_manager/token_limit_recovery.py delete mode 100644 src/strands/experimental/hooks/providers/__init__.py delete mode 100644 src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py create mode 100644 tests/strands/agent/conversation_manager/__init__.py create mode 100644 tests/strands/agent/conversation_manager/test_token_limit_recovery.py delete mode 100644 tests/strands/experimental/hooks/providers/__init__.py delete mode 100644 tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index c86b64ff3..e258cb324 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,7 +20,6 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool -from ..experimental.hooks.events import EventLoopFailureEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -38,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException +from ..types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -54,13 +53,14 @@ T = TypeVar("T", bound=BaseModel) -# Sentinel class and object to distinguish between explicit None and default parameter value +# Sentinel classes to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: """Sentinel class to distinguish between explicit None and default parameter value.""" pass + _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" @@ 
-247,7 +247,7 @@ def __init__( state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. hooks: hooks to be added to the agent hook registry - Defaults to None. + Defaults to set of if None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. """ @@ -587,29 +587,17 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A except ContextWindowOverflowException as e: # Try reducing the context size and retrying self.conversation_manager.reduce_context(self, e=e) + except MaxTokensReachedException as e: + # Recover conversation state after token limit exceeded, then continue with next cycle + self.conversation_manager.handle_token_limit_reached(self, e=e) - # Sync agent after reduce_context to keep conversation_manager_state up to date in the session - if self._session_manager: - self._session_manager.sync_agent(self) - - # If the events have been handled, attempt to restart the event loop in the now-healthy state - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: - yield event - except Exception as e: - """ - Catch all other exceptions which are unrecoverable without intervention. 
- Reraise exception if EventLoopFailureEvent.should_continue is false - """ - event_loop_failure_event = EventLoopFailureEvent(agent=self, exception=e) - self.hooks.invoke_callbacks(event_loop_failure_event) - if not event_loop_failure_event.should_continue_loop: - raise + # Sync agent after handling exception to keep conversation_manager_state up to date in the session + if self._session_manager: + self._session_manager.sync_agent(self) - # If the events have been handled, attempt to restart the event loop in the now-healthy state - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: - yield event + events = self._execute_event_loop_cycle(invocation_state) + async for event in events: + yield event def _record_tool_execution( self, diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 2c1ee7847..c2899209b 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional from ...types.content import Message +from ...types.exceptions import MaxTokensReachedException if TYPE_CHECKING: from ...agent.agent import Agent @@ -86,3 +87,18 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs **kwargs: Additional keyword arguments for future extensibility. """ pass + + @abstractmethod + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Called when MaxTokensReachedException is thrown to recover conversation state. + + This method should implement recovery strategies when the token limit is exceeded and the message array + may be in a broken state. It is called outside the event loop to apply default recovery mechanisms. + + Args: + agent: The agent whose conversation state will be recovered. 
+ This list is modified in-place. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + pass diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..29fa1c442 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager @@ -44,3 +44,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs raise e else: raise ContextWindowOverflowException("Context window overflowed!") + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Does not handle token limit recovery and raises the exception. + + Args: + agent: The agent whose conversation state will remain unmodified. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + e: The provided exception. 
+ """ + raise e diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..f96dbff27 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -7,8 +7,9 @@ from ...agent.agent import Agent from ...types.content import Messages -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager +from .token_limit_recovery import recover_from_max_tokens_reached logger = logging.getLogger(__name__) @@ -177,3 +178,13 @@ def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[i return idx return None + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply sliding window strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. 
+ """ + recover_from_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 60e832215..fe0d13fa4 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -6,8 +6,9 @@ from typing_extensions import override from ...types.content import Message -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager +from .token_limit_recovery import recover_from_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -250,3 +251,13 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply summarization strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. 
+ """ + recover_from_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/token_limit_recovery.py new file mode 100644 index 000000000..a0935f3a3 --- /dev/null +++ b/src/strands/agent/conversation_manager/token_limit_recovery.py @@ -0,0 +1,66 @@ +"""Shared utility for handling token limit recovery in conversation managers.""" + +import logging +from typing import TYPE_CHECKING + +from ...types.content import ContentBlock, Message +from ...types.exceptions import MaxTokensReachedException +from ...types.tools import ToolUse + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + + +def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: + """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. + + This function fixes incomplete tool uses that may occur when the model's response is truncated + due to token limits. It: + + 1. Inspects each content block in the incomplete message for invalid tool uses + 2. Replaces incomplete tool use blocks with informative text messages + 3. Preserves valid content blocks in the corrected message + 4. Adds the corrected message to the agent's conversation history + + Args: + agent: The agent whose conversation will be updated with the corrected message. + exception: The MaxTokensReachedException containing the incomplete message. 
+ """ + logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") + + incomplete_message: Message = exception.incomplete_message + + if not incomplete_message["content"]: + # Cannot correct invalid content block if content is empty + return + + valid_content: list[ContentBlock] = [] + for content in incomplete_message["content"]: + tool_use: ToolUse | None = content.get("toolUse") + if not tool_use: + valid_content.append(content) + continue + + # Check if tool use is incomplete (missing or empty required fields) + tool_name = tool_use.get("name") + if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): + # Tool use is incomplete due to max_tokens truncation + display_name = tool_name if tool_name else "" + logger.warning( + "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name + ) + + valid_content.append( + { + "text": f"The selected tool {display_name}'s tool use was incomplete due " + f"to maximum token limits being reached." + } + ) + else: + # ToolUse was invalid for an unknown reason. 
Cannot correct, return without modifying + return + + valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} + agent.messages.append(valid_message) diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 384d8a505..098d4cf0d 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -5,7 +5,6 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, - EventLoopFailureEvent, ) __all__ = [ @@ -13,5 +12,4 @@ "AfterToolInvocationEvent", "BeforeModelInvocationEvent", "AfterModelInvocationEvent", - "EventLoopFailureEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 128882821..d03e65d85 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -121,29 +121,3 @@ class ModelStopResponse: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True - - -@dataclass -class EventLoopFailureEvent(HookEvent): - """Event triggered when the event loop encounters a failure. - - This event is fired when an exception occurs during event loop execution, - allowing hook providers to handle the failure or perform recovery actions. - - Attributes: - exception: The exception that caused the event loop failure. - should_continue_loop: Flag that hooks can set to True to indicate they have - handled the exception and the event loop should continue normally. - - Warning: - Setting should_continue_loop=True without properly addressing the underlying - cause of the exception may result in infinite loops if the same failure - condition persists. Hooks should implement appropriate error handling, - retry limits, or state modifications to prevent recurring failures. 
- """ - - exception: Exception - should_continue_loop: bool = False - - def _can_write(self, name: str) -> bool: - return name == "should_continue_loop" diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py deleted file mode 100644 index 5b74733b8..000000000 --- a/src/strands/experimental/hooks/providers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Hook providers for experimental Strands Agents functionality. - -This package contains experimental hook providers that extend the core agent functionality -with additional capabilities. -""" diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py deleted file mode 100644 index c8b12d98d..000000000 --- a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Hook provider for correcting incomplete tool uses due to token limits. - -This module provides the CorrectToolUseHookProvider class, which handles scenarios where -the model's response is truncated due to maximum token limits, resulting in incomplete -or malformed tool use entries. The provider automatically corrects these issues to allow -the agent conversation to continue gracefully. -""" - -import logging -from typing import Any - -from strands.experimental.hooks.events import EventLoopFailureEvent -from strands.hooks import HookProvider, HookRegistry, MessageAddedEvent -from strands.types.content import ContentBlock, Message -from strands.types.exceptions import MaxTokensReachedException -from strands.types.tools import ToolUse - -logger = logging.getLogger(__name__) - - -class CorrectToolUseHookProvider(HookProvider): - """Hook provider that handles MaxTokensReachedException by fixing incomplete tool uses. - - This hook provider is triggered when a MaxTokensReachedException occurs during event loop execution. 
- When the model's response is truncated due to token limits, tool use entries may be incomplete - or missing required fields (name, input, toolUseId). - - The provider fixes these issues by: - - 1. Inspecting each content block in the incomplete message for invalid tool uses - 2. Replacing incomplete tool use blocks with informative text messages - 3. Preserving valid content blocks in the corrected message - 4. Adding the corrected message to the agent's conversation history - 5. Allowing the event loop to continue processing - - If a tool use is invalid for unknown reasons, not due to empty fields, the hook - allows the original exception to propagate to avoid unsafe recovery attempts. - """ - - def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: - """Register hook to handle EventLoopFailureEvent for MaxTokensReachedException.""" - registry.add_callback(EventLoopFailureEvent, self._handle_max_tokens_reached) - - def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: - """Handle MaxTokensReachedException by cleaning up orphaned tool uses and allowing continuation.""" - if not isinstance(event.exception, MaxTokensReachedException): - return - - logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") - - incomplete_message: Message = event.exception.incomplete_message - - if not incomplete_message["content"]: - # Cannot correct invalid content block if content is empty - return - - valid_content: list[ContentBlock] = [] - for content in incomplete_message["content"]: - tool_use: ToolUse | None = content.get("toolUse") - if not tool_use: - valid_content.append(content) - continue - - """ - Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented - using pydantic, we inspect each field. 
- """ - # Check if tool use is incomplete (missing or empty required fields) - tool_name = tool_use.get("name") - if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): - """ - If tool_use does not conform to the expected schema it means the max_tokens issue resulted in it not - being populated it correctly. - - It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure - on the next iteration. - """ - display_name = tool_name if tool_name else "" - logger.warning( - "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name - ) - - valid_content.append( - { - "text": f"The selected tool {display_name}'s tool use was incomplete due " - f"to maximum token limits being reached." - } - ) - else: - # ToolUse was invalid for an unknown reason. Cannot correct, return and allow exception to propagate up. - return - - valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} - event.agent.messages.append(valid_message) - event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message)) - event.should_continue_loop = True diff --git a/tests/strands/agent/conversation_manager/__init__.py b/tests/strands/agent/conversation_manager/__init__.py new file mode 100644 index 000000000..d5ee2d119 --- /dev/null +++ b/tests/strands/agent/conversation_manager/__init__.py @@ -0,0 +1 @@ +# Test package for conversation manager diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py new file mode 100644 index 000000000..8d1655c45 --- /dev/null +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -0,0 +1,200 @@ +"""Tests for token limit recovery utility.""" + +import pytest + +from strands.agent.agent import Agent +from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached 
+from strands.types.content import Message +from strands.types.exceptions import MaxTokensReachedException + + +def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): + """Test recovery when incomplete tool use is present in the message.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId + ] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should add one corrected message + assert len(agent.messages) == initial_message_count + 1 + + # Check the corrected message content + corrected_message = agent.messages[-1] + assert corrected_message["role"] == "assistant" + assert len(corrected_message["content"]) == 2 + + # First content block should be preserved + assert corrected_message["content"][0] == {"text": "I'll help you with that."} + + # Second content block should be replaced with error message + assert "text" in corrected_message["content"][1] + assert "calculator" in corrected_message["content"][1]["text"] + assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] + + +def test_recover_from_max_tokens_reached_with_unknown_tool_name(): + """Test recovery when tool use has no name.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name + ] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should add one corrected message + assert len(agent.messages) == initial_message_count + 1 + 
+ # Check the corrected message content + corrected_message = agent.messages[-1] + assert corrected_message["role"] == "assistant" + assert len(corrected_message["content"]) == 1 + + # Content should be replaced with error message using + assert "text" in corrected_message["content"][0] + assert "" in corrected_message["content"][0]["text"] + assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] + + +def test_recover_from_max_tokens_reached_with_valid_tool_use(): + """Test that valid tool uses are not modified and function returns early.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid + ] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should not add any message since tool use was valid + assert len(agent.messages) == initial_message_count + + +def test_recover_from_max_tokens_reached_with_empty_content(): + """Test that empty content is handled gracefully.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should not add any message since content is empty + assert len(agent.messages) == initial_message_count + + +def test_recover_from_max_tokens_reached_with_mixed_content(): + """Test recovery with mix of valid content and incomplete tool use.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Let 
me calculate this for you."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete + {"text": "And then I'll explain the result."}, + ] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should add one corrected message + assert len(agent.messages) == initial_message_count + 1 + + # Check the corrected message content + corrected_message = agent.messages[-1] + assert corrected_message["role"] == "assistant" + assert len(corrected_message["content"]) == 3 + + # First and third content blocks should be preserved + assert corrected_message["content"][0] == {"text": "Let me calculate this for you."} + assert corrected_message["content"][2] == {"text": "And then I'll explain the result."} + + # Second content block should be replaced with error message + assert "text" in corrected_message["content"][1] + assert "calculator" in corrected_message["content"][1]["text"] + assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] + + +def test_recover_from_max_tokens_reached_preserves_non_tool_content(): + """Test that non-tool content is preserved as-is.""" + agent = Agent() + initial_message_count = len(agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Here's some text."}, + {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete + ] + } + + exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + recover_from_max_tokens_reached(agent, exception) + + # Should add one corrected message + assert len(agent.messages) == initial_message_count + 1 + + # Check the corrected message content + corrected_message = agent.messages[-1] + assert corrected_message["role"] == "assistant" + assert 
len(corrected_message["content"]) == 3 + + # First two content blocks should be preserved exactly + assert corrected_message["content"][0] == {"text": "Here's some text."} + assert corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} + + # Third content block should be replaced with error message + assert "text" in corrected_message["content"][2] + assert "" in corrected_message["content"][2]["text"] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4e310dace..9dd802f4e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,7 +19,7 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -547,6 +547,72 @@ def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): agent("Test!") +def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(mock_model, agent, agenerator): + """Test that MaxTokensReachedException triggers conversation manager handle_token_limit_reached.""" + conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) + agent.conversation_manager = conversation_manager_spy + + incomplete_message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId + ] + } + + 
mock_model.mock_stream.side_effect = [ + MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ), + agenerator( + [ + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Recovered response"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + ] + + result = agent("Test message") + + # Verify handle_token_limit_reached was called + assert conversation_manager_spy.handle_token_limit_reached.call_count == 1 + + # Verify the call was made with the correct exception + call_args = conversation_manager_spy.handle_token_limit_reached.call_args + args, kwargs = call_args + assert len(args) >= 2 # Should have at least agent and exception + assert isinstance(args[1], MaxTokensReachedException) # Second argument should be the exception + + # Verify apply_management was also called + assert conversation_manager_spy.apply_management.call_count > 0 + + # Verify the agent continued and produced a result + assert result is not None + + +def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent): + """Test that MaxTokensReachedException with NullConversationManager raises the exception.""" + agent.conversation_manager = NullConversationManager() + + incomplete_message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId + ] + } + + mock_model.mock_stream.side_effect = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + with pytest.raises(MaxTokensReachedException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agenerator): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy diff --git a/tests/strands/agent/test_agent_hooks.py 
b/tests/strands/agent/test_agent_hooks.py index 66eb86808..cd89fbc7a 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, Mock import pytest from pydantic import BaseModel @@ -10,7 +10,6 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, - EventLoopFailureEvent, ) from strands.hooks import ( AfterInvocationEvent, @@ -19,7 +18,6 @@ MessageAddedEvent, ) from strands.types.content import Messages -from strands.types.exceptions import MaxTokensReachedException from strands.types.tools import ToolResult, ToolUse from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -37,7 +35,6 @@ def hook_provider(): BeforeModelInvocationEvent, AfterModelInvocationEvent, MessageAddedEvent, - EventLoopFailureEvent, ] ) @@ -295,53 +292,3 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert next(events) == AfterInvocationEvent(agent=agent) assert len(agent.messages) == 1 - - -def test_event_loop_failure_event_exception_rethrown_when_not_handled(agent, hook_provider): - """Test that EventLoopFailureEvent is triggered and exceptions are re-thrown when not handled.""" - - # Mock event_loop_cycle to raise a general exception (not ContextWindowOverflowException) - with patch("strands.agent.agent.event_loop_cycle") as mock_cycle: - mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"}) - - with pytest.raises(MaxTokensReachedException): - agent("test message") - length, events = hook_provider.get_events() - failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)] - - assert len(failure_events) == 1 - assert isinstance(failure_events[0].exception, MaxTokensReachedException) - assert failure_events[0].should_continue_loop 
is False - - -def test_event_loop_failure_event_exception_handled_by_hook(agent, hook_provider): - """Test that EventLoopFailureEvent allows hooks to handle exceptions and continue execution.""" - - first_call = True - - def hook_callback(event: EventLoopFailureEvent): - nonlocal first_call - # Hook handles the exception by setting should_continue_loop to True - event.should_continue_loop = first_call - first_call = False - - agent.hooks.add_callback(EventLoopFailureEvent, hook_callback) - - # Mock event_loop_cycle to raise a general exception - with patch("strands.agent.agent.event_loop_cycle") as mock_cycle: - mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"}) - - # Should NOT raise exception due to hook handling on the first failure - with pytest.raises(MaxTokensReachedException): - agent("test message") - - length, events = hook_provider.get_events() - failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)] - - assert len(failure_events) == 2 - - assert isinstance(failure_events[0].exception, MaxTokensReachedException) - assert failure_events[0].should_continue_loop is True - - assert isinstance(failure_events[1].exception, MaxTokensReachedException) - assert failure_events[1].should_continue_loop is False diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 77d7dcce8..e3452824e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -3,7 +3,9 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager -from strands.types.exceptions import ContextWindowOverflowException +from strands.types.content import Message +from strands.types.exceptions 
import ContextWindowOverflowException, MaxTokensReachedException
+from strands.types.content import Message


 @pytest.fixture
@@ -204,6 +206,44 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated():
     assert messages == expected_messages


+def test_sliding_window_conversation_manager_handle_token_limit_reached():
+    """Test that SlidingWindowConversationManager handles token limit recovery."""
+    manager = SlidingWindowConversationManager()
+    test_agent = Agent()
+    initial_message_count = len(test_agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ]
+    }
+
+    test_exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    manager.handle_token_limit_reached(test_agent, test_exception)
+
+    # Should add one corrected message
+    assert len(test_agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = test_agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 2
+
+    # First content block should be preserved
+    assert corrected_message["content"][0] == {"text": "I'll help you with that."}
+
+    # Second content block should be replaced with error message
+    assert "text" in corrected_message["content"][1]
+    assert "calculator" in corrected_message["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"]
+
+
 def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception():
     """Test that NullConversationManager doesn't modify messages."""
     manager = NullConversationManager()
@@ -246,3 +286,53 @@ def test_null_conversation_does_not_restore_with_incorrect_state():
     with pytest.raises(ValueError): 
manager.restore_from_session({}) + + +def test_summarizing_conversation_manager_handle_token_limit_reached(): + """Test that SummarizingConversationManager handles token limit recovery.""" + from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager + + manager = SummarizingConversationManager() + test_agent = Agent() + initial_message_count = len(test_agent.messages) + + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name + ] + } + + test_exception = MaxTokensReachedException( + message="Token limit reached", + incomplete_message=incomplete_message + ) + + manager.handle_token_limit_reached(test_agent, test_exception) + + # Should add one corrected message + assert len(test_agent.messages) == initial_message_count + 1 + + # Check the corrected message content + corrected_message = test_agent.messages[-1] + assert corrected_message["role"] == "assistant" + assert len(corrected_message["content"]) == 1 + + # Content should be replaced with error message using + assert "text" in corrected_message["content"][0] + assert "" in corrected_message["content"][0]["text"] + assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] + + +def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): + """Test that NullConversationManager raises the provided exception.""" + manager = NullConversationManager() + test_agent = Agent() + test_message: Message = { + "role": "assistant", + "content": [{"text": "Hello"}], + } + test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message) + + with pytest.raises(MaxTokensReachedException): + manager.handle_token_limit_reached(test_agent, test_exception) diff --git a/tests/strands/experimental/hooks/providers/__init__.py b/tests/strands/experimental/hooks/providers/__init__.py deleted file mode 100644 index 
8b1378917..000000000 --- a/tests/strands/experimental/hooks/providers/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py deleted file mode 100644 index 93d672ab2..000000000 --- a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Unit tests for CorrectToolUseHookProvider.""" - -from unittest.mock import Mock - -import pytest - -from strands.experimental.hooks.events import EventLoopFailureEvent -from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider -from strands.hooks import HookRegistry -from strands.types.content import Message -from strands.types.exceptions import MaxTokensReachedException - - -@pytest.fixture -def hook_provider(): - """Create a CorrectToolUseHookProvider instance.""" - return CorrectToolUseHookProvider() - - -@pytest.fixture -def mock_agent(): - """Create a mock agent with messages and hooks.""" - agent = Mock() - agent.messages = [] - agent.hooks = Mock() - return agent - - -@pytest.fixture -def mock_registry(): - """Create a mock hook registry.""" - return Mock(spec=HookRegistry) - - -def test_register_hooks(hook_provider, mock_registry): - """Test that the hook provider registers the correct callback.""" - hook_provider.register_hooks(mock_registry) - - mock_registry.add_callback.assert_called_once_with(EventLoopFailureEvent, hook_provider._handle_max_tokens_reached) - - -def test_handle_non_max_tokens_exception(hook_provider, mock_agent): - """Test that non-MaxTokensReachedException events are ignored.""" - other_exception = ValueError("Some other error") - event = EventLoopFailureEvent(agent=mock_agent, exception=other_exception) - - hook_provider._handle_max_tokens_reached(event) - - # Should not modify the agent or event - assert len(mock_agent.messages) 
== 0 - assert not event.should_continue_loop - mock_agent.hooks.invoke_callbacks.assert_not_called() - - -@pytest.mark.parametrize( - "incomplete_tool_use,expected_tool_name", - [ - ({"toolUseId": "tool-123", "input": {"param": "value"}}, ""), # Missing name - ({"name": "test_tool", "toolUseId": "tool-123"}, "test_tool"), # Missing input - ({"name": "test_tool", "input": {}, "toolUseId": "tool-123"}, "test_tool"), # Empty input - ({"name": "test_tool", "input": {"param": "value"}}, "test_tool"), # Missing toolUseId - ], -) -def test_handle_max_tokens_with_incomplete_tool_use(hook_provider, mock_agent, incomplete_tool_use, expected_tool_name): - """Test handling various incomplete tool use scenarios.""" - incomplete_message: Message = { - "role": "user", # Test role preservation - "content": [{"text": "I'll use a tool"}, {"toolUse": incomplete_tool_use}], - } - - exception = MaxTokensReachedException("Max tokens reached", incomplete_message) - event = EventLoopFailureEvent(agent=mock_agent, exception=exception) - - hook_provider._handle_max_tokens_reached(event) - - # Should add corrected message with error text and preserve role - assert len(mock_agent.messages) == 1 - added_message = mock_agent.messages[0] - assert added_message["role"] == "user" # Role preserved - assert len(added_message["content"]) == 2 - assert added_message["content"][0]["text"] == "I'll use a tool" - assert f"The selected tool {expected_tool_name}'s tool use was incomplete" in added_message["content"][1]["text"] - assert "maximum token limits being reached" in added_message["content"][1]["text"] - - assert event.should_continue_loop - - -def test_handle_max_tokens_with_no_content(hook_provider, mock_agent): - """Test handling message with no content blocks.""" - incomplete_message: Message = {"role": "assistant", "content": []} - - exception = MaxTokensReachedException("Max tokens reached", incomplete_message) - event = EventLoopFailureEvent(agent=mock_agent, exception=exception) - - 
hook_provider._handle_max_tokens_reached(event) - - # Should add empty message and continue - assert len(mock_agent.messages) == 0 - assert not event.should_continue_loop diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 6bad70636..d50452801 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -3,7 +3,7 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider +from strands.agent import NullConversationManager from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException @@ -18,9 +18,9 @@ def story_tool(story: str) -> str: return story -def test_context_window_overflow(): +def test_max_tokens_reached(): model = BedrockModel(max_tokens=100) - agent = Agent(model=model, tools=[story_tool]) + agent = Agent(model=model, tools=[story_tool], conversation_manager=NullConversationManager()) with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") @@ -31,8 +31,7 @@ def test_context_window_overflow(): def test_max_tokens_reached_with_hook_provider(): """Test that MaxTokensReachedException can be handled by a hook provider.""" model = BedrockModel(max_tokens=100) - hook_provider = CorrectToolUseHookProvider() - agent = Agent(model=model, tools=[story_tool], hooks=[hook_provider]) + agent = Agent(model=model, tools=[story_tool]) # Defaults to include SlidingWindowConversationManager # This should NOT raise an exception because the hook handles it agent("Tell me a story!") From 5906fc2f8d405c2e1326d9b603707f75123285e7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 5 Aug 2025 18:33:36 -0400 Subject: [PATCH 12/23] linting --- src/strands/agent/agent.py | 5 +- .../token_limit_recovery.py | 10 +- .../test_token_limit_recovery.py | 116 +++++++----------- tests/strands/agent/test_agent.py | 18 ++- 
.../agent/test_conversation_manager.py | 48 ++++---- 5 files changed, 82 insertions(+), 115 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e258cb324..e749183fc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -53,14 +53,13 @@ T = TypeVar("T", bound=BaseModel) -# Sentinel classes to distinguish between explicit None and default parameter value +# Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: """Sentinel class to distinguish between explicit None and default parameter value.""" pass - _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" @@ -247,7 +246,7 @@ def __init__( state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. hooks: hooks to be added to the agent hook registry - Defaults to set of if None. + Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. """ diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/token_limit_recovery.py index a0935f3a3..ceb32c735 100644 --- a/src/strands/agent/conversation_manager/token_limit_recovery.py +++ b/src/strands/agent/conversation_manager/token_limit_recovery.py @@ -15,15 +15,15 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. - + This function fixes incomplete tool uses that may occur when the model's response is truncated due to token limits. It: - + 1. Inspects each content block in the incomplete message for invalid tool uses 2. 
Replaces incomplete tool use blocks with informative text messages 3. Preserves valid content blocks in the corrected message 4. Adds the corrected message to the agent's conversation history - + Args: agent: The agent whose conversation will be updated with the corrected message. exception: The MaxTokensReachedException containing the incomplete message. @@ -48,9 +48,7 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): # Tool use is incomplete due to max_tokens truncation display_name = tool_name if tool_name else "" - logger.warning( - "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name - ) + logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) valid_content.append( { diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py index 8d1655c45..9ae6c8722 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -1,6 +1,5 @@ """Tests for token limit recovery utility.""" -import pytest from strands.agent.agent import Agent from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached @@ -12,33 +11,30 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): """Test recovery when incomplete tool use is present in the message.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "I'll help you with that."}, {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + 
exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 2 - + # First content block should be preserved assert corrected_message["content"][0] == {"text": "I'll help you with that."} - + # Second content block should be replaced with error message assert "text" in corrected_message["content"][1] assert "calculator" in corrected_message["content"][1]["text"] @@ -49,29 +45,26 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name(): """Test recovery when tool use has no name.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 1 - + # Content should be replaced with error message using assert "text" in corrected_message["content"][0] assert "" in corrected_message["content"][0]["text"] @@ -82,22 +75,19 @@ def test_recover_from_max_tokens_reached_with_valid_tool_use(): """Test that valid tool uses are not modified and function returns early.""" agent = Agent() initial_message_count = 
len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "I'll help you with that."}, {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should not add any message since tool use was valid assert len(agent.messages) == initial_message_count @@ -106,19 +96,13 @@ def test_recover_from_max_tokens_reached_with_empty_content(): """Test that empty content is handled gracefully.""" agent = Agent() initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [] - } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + incomplete_message: Message = {"role": "assistant", "content": []} + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should not add any message since content is empty assert len(agent.messages) == initial_message_count @@ -127,35 +111,32 @@ def test_recover_from_max_tokens_reached_with_mixed_content(): """Test recovery with mix of valid content and incomplete tool use.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "Let me calculate this for you."}, {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete {"text": "And then I'll explain the result."}, - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = 
MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 3 - + # First and third content blocks should be preserved assert corrected_message["content"][0] == {"text": "Let me calculate this for you."} assert corrected_message["content"][2] == {"text": "And then I'll explain the result."} - + # Second content block should be replaced with error message assert "text" in corrected_message["content"][1] assert "calculator" in corrected_message["content"][1]["text"] @@ -166,35 +147,32 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content(): """Test that non-tool content is preserved as-is.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "Here's some text."}, {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 3 - + # First two content blocks should be preserved exactly assert corrected_message["content"][0] == {"text": "Here's some text."} assert 
corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} - + # Third content block should be replaced with error message assert "text" in corrected_message["content"][2] assert "" in corrected_message["content"][2]["text"] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 9dd802f4e..87aafe7a2 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -557,14 +557,11 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( "content": [ {"text": "I'll help you with that."}, {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ] + ], } mock_model.mock_stream.side_effect = [ - MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ), + MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message), agenerator( [ {"contentBlockStart": {"start": {}}}, @@ -579,16 +576,16 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( # Verify handle_token_limit_reached was called assert conversation_manager_spy.handle_token_limit_reached.call_count == 1 - + # Verify the call was made with the correct exception call_args = conversation_manager_spy.handle_token_limit_reached.call_args args, kwargs = call_args assert len(args) >= 2 # Should have at least agent and exception assert isinstance(args[1], MaxTokensReachedException) # Second argument should be the exception - + # Verify apply_management was also called assert conversation_manager_spy.apply_management.call_count > 0 - + # Verify the agent continued and produced a result assert result is not None @@ -601,12 +598,11 @@ def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_e "role": "assistant", "content": [ {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ] + ], } 
mock_model.mock_stream.side_effect = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message + message="Token limit reached", incomplete_message=incomplete_message ) with pytest.raises(MaxTokensReachedException): diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index e3452824e..3e5bd56f3 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -4,8 +4,10 @@ from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.types.content import Message -from strands.types.exceptions import ContextWindowOverflowException, MaxTokensReachedException, MaxTokensReachedException -from strands.types.content import Message +from strands.types.exceptions import ( + ContextWindowOverflowException, + MaxTokensReachedException, +) @pytest.fixture @@ -211,33 +213,30 @@ def test_sliding_window_conversation_manager_handle_token_limit_reached(): manager = SlidingWindowConversationManager() test_agent = Agent() initial_message_count = len(test_agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "I'll help you with that."}, {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ] + ], } - - test_exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + manager.handle_token_limit_reached(test_agent, test_exception) - + # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = test_agent.messages[-1] assert 
corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 2 - + # First content block should be preserved assert corrected_message["content"][0] == {"text": "I'll help you with that."} - + # Second content block should be replaced with error message assert "text" in corrected_message["content"][1] assert "calculator" in corrected_message["content"][1]["text"] @@ -291,33 +290,30 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): def test_summarizing_conversation_manager_handle_token_limit_reached(): """Test that SummarizingConversationManager handles token limit recovery.""" from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager - + manager = SummarizingConversationManager() test_agent = Agent() initial_message_count = len(test_agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name - ] + ], } - - test_exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + manager.handle_token_limit_reached(test_agent, test_exception) - + # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = test_agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 1 - + # Content should be replaced with error message using assert "text" in corrected_message["content"][0] assert "" in corrected_message["content"][0]["text"] From 87445a3224af4d2000b65fc8972abe8d0b9c8220 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 09:45:46 -0400 Subject: [PATCH 13/23] fix: test contained incorrect assertions --- src/strands/agent/agent.py | 4 ++-- 
.../test_token_limit_recovery.py | 1 - tests/strands/agent/test_agent.py | 16 ++++++---------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e749183fc..044ff4e67 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -585,10 +585,10 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A return except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self, e=e) + self.conversation_manager.reduce_context(agent=self, e=e) except MaxTokensReachedException as e: # Recover conversation state after token limit exceeded, then continue with next cycle - self.conversation_manager.handle_token_limit_reached(self, e=e) + self.conversation_manager.handle_token_limit_reached(agent=self, e=e) # Sync agent after handling exception to keep conversation_manager_state up to date in the session if self._session_manager: diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py index 9ae6c8722..afbe73a39 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -1,6 +1,5 @@ """Tests for token limit recovery utility.""" - from strands.agent.agent import Agent from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached from strands.types.content import Message diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 87aafe7a2..1bc5ad78a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -561,7 +561,9 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( } mock_model.mock_stream.side_effect = [ + # First occurrence 
MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message), + # On retry the loop should succeed agenerator( [ {"contentBlockStart": {"start": {}}}, @@ -572,22 +574,16 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( ), ] - result = agent("Test message") + agent("Test message") # Verify handle_token_limit_reached was called assert conversation_manager_spy.handle_token_limit_reached.call_count == 1 # Verify the call was made with the correct exception call_args = conversation_manager_spy.handle_token_limit_reached.call_args - args, kwargs = call_args - assert len(args) >= 2 # Should have at least agent and exception - assert isinstance(args[1], MaxTokensReachedException) # Second argument should be the exception - - # Verify apply_management was also called - assert conversation_manager_spy.apply_management.call_count > 0 - - # Verify the agent continued and produced a result - assert result is not None + kwargs = list(call_args[1].values()) + assert isinstance(kwargs[0], Agent) + assert isinstance(kwargs[1], MaxTokensReachedException) def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent): From 924fea9e68ecca7d310bbad1a8b9d607562b76a6 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 10:26:50 -0400 Subject: [PATCH 14/23] fix: add event emission --- .../agent/conversation_manager/__init__.py | 2 + .../conversation_manager.py | 5 +- .../null_conversation_manager.py | 27 +++--- ...recover_tool_use_on_max_tokens_reached.py} | 6 +- .../sliding_window_conversation_manager.py | 4 +- .../summarizing_conversation_manager.py | 4 +- .../test_token_limit_recovery.py | 83 ++++++++++++++++--- 7 files changed, 96 insertions(+), 35 deletions(-) rename src/strands/agent/conversation_manager/{token_limit_recovery.py => recover_tool_use_on_max_tokens_reached.py} (89%) diff --git a/src/strands/agent/conversation_manager/__init__.py 
b/src/strands/agent/conversation_manager/__init__.py index c59623215..7e7e0c6c5 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -15,12 +15,14 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager __all__ = [ "ConversationManager", "NullConversationManager", + "recover_tool_use_on_max_tokens_reached", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index c2899209b..57ce59f93 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -88,12 +88,11 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs """ pass - @abstractmethod def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: """Called when MaxTokensReachedException is thrown to recover conversation state. This method should implement recovery strategies when the token limit is exceeded and the message array - may be in a broken state. It is called outside the event loop to apply default recovery mechanisms. + may be in a broken state. Args: agent: The agent whose conversation state will be recovered. @@ -101,4 +100,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. 
""" - pass + raise e diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 29fa1c442..fb9868741 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -45,15 +45,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs else: raise ContextWindowOverflowException("Context window overflowed!") - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Does not handle token limit recovery and raises the exception. - - Args: - agent: The agent whose conversation state will remain unmodified. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - - Raises: - e: The provided exception. - """ - raise e + # + # def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + # """Does not handle token limit recovery and raises the exception. + # + # Args: + # agent: The agent whose conversation state will remain unmodified. + # e: The MaxTokensReachedException that triggered the recovery. + # **kwargs: Additional keyword arguments for future extensibility. + # + # Raises: + # e: The provided exception. 
+ # """ + # raise e diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py similarity index 89% rename from src/strands/agent/conversation_manager/token_limit_recovery.py rename to src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index ceb32c735..516c3ec36 100644 --- a/src/strands/agent/conversation_manager/token_limit_recovery.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -3,6 +3,7 @@ import logging from typing import TYPE_CHECKING +from ...hooks import MessageAddedEvent from ...types.content import ContentBlock, Message from ...types.exceptions import MaxTokensReachedException from ...types.tools import ToolUse @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: +def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. This function fixes incomplete tool uses that may occur when the model's response is truncated @@ -28,7 +29,7 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE agent: The agent whose conversation will be updated with the corrected message. exception: The MaxTokensReachedException containing the incomplete message. 
""" - logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") + logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message @@ -62,3 +63,4 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f96dbff27..0559e0efa 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -9,7 +9,7 @@ from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager -from .token_limit_recovery import recover_from_max_tokens_reached +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached logger = logging.getLogger(__name__) @@ -187,4 +187,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. 
""" - recover_from_max_tokens_reached(agent, e) + recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index fe0d13fa4..1dc5d907a 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -8,7 +8,7 @@ from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager -from .token_limit_recovery import recover_from_max_tokens_reached +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -260,4 +260,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. 
""" - recover_from_max_tokens_reached(agent, e) + recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py index afbe73a39..006f5db25 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -1,14 +1,22 @@ """Tests for token limit recovery utility.""" +from unittest.mock import Mock + from strands.agent.agent import Agent -from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached +from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( + recover_tool_use_on_max_tokens_reached, +) +from strands.hooks import MessageAddedEvent from strands.types.content import Message from strands.types.exceptions import MaxTokensReachedException -def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): +def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): """Test recovery when incomplete tool use is present in the message.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -21,7 +29,7 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -39,10 +47,20 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): assert "calculator" in corrected_message["content"][1]["text"] assert "incomplete due to maximum token limits" 
in corrected_message["content"][1]["text"] + # Verify that the MessageAddedEvent callback was invoked + mock_invoke_callbacks.assert_called_once() + call_args = mock_invoke_callbacks.call_args[0][0] + assert isinstance(call_args, MessageAddedEvent) + assert call_args.agent == agent + assert call_args.message == corrected_message + -def test_recover_from_max_tokens_reached_with_unknown_tool_name(): +def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): """Test recovery when tool use has no name.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -54,7 +72,7 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -69,10 +87,20 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name(): assert "" in corrected_message["content"][0]["text"] assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] + # Verify that the MessageAddedEvent callback was invoked + mock_invoke_callbacks.assert_called_once() + call_args = mock_invoke_callbacks.call_args[0][0] + assert isinstance(call_args, MessageAddedEvent) + assert call_args.agent == agent + assert call_args.message == corrected_message -def test_recover_from_max_tokens_reached_with_valid_tool_use(): + +def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): """Test that valid tool uses are not modified and function returns early.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks 
= mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -85,30 +113,42 @@ def test_recover_from_max_tokens_reached_with_valid_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since tool use was valid assert len(agent.messages) == initial_message_count + # Verify that the MessageAddedEvent callback was NOT invoked + mock_invoke_callbacks.assert_not_called() + -def test_recover_from_max_tokens_reached_with_empty_content(): +def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): """Test that empty content is handled gracefully.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = {"role": "assistant", "content": []} exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since content is empty assert len(agent.messages) == initial_message_count + # Verify that the MessageAddedEvent callback was NOT invoked + mock_invoke_callbacks.assert_not_called() -def test_recover_from_max_tokens_reached_with_mixed_content(): + +def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): """Test recovery with mix of valid content and incomplete tool use.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -122,7 +162,7 @@ def 
test_recover_from_max_tokens_reached_with_mixed_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -141,10 +181,20 @@ def test_recover_from_max_tokens_reached_with_mixed_content(): assert "calculator" in corrected_message["content"][1]["text"] assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] + # Verify that the MessageAddedEvent callback was invoked + mock_invoke_callbacks.assert_called_once() + call_args = mock_invoke_callbacks.call_args[0][0] + assert isinstance(call_args, MessageAddedEvent) + assert call_args.agent == agent + assert call_args.message == corrected_message + -def test_recover_from_max_tokens_reached_preserves_non_tool_content(): +def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): """Test that non-tool content is preserved as-is.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -158,7 +208,7 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -175,3 +225,10 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content(): # Third content block should be replaced with error message assert "text" in corrected_message["content"][2] assert "" in corrected_message["content"][2]["text"] + + # Verify that 
the MessageAddedEvent callback was invoked + mock_invoke_callbacks.assert_called_once() + call_args = mock_invoke_callbacks.call_args[0][0] + assert isinstance(call_args, MessageAddedEvent) + assert call_args.agent == agent + assert call_args.message == corrected_message From 104f6b425fbbe5a2414b8f10281469d67d6ab1de Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:03:48 -0400 Subject: [PATCH 15/23] feat: move to async --- src/strands/agent/agent.py | 2 +- .../conversation_manager.py | 2 +- .../null_conversation_manager.py | 14 ------- .../recover_tool_use_on_max_tokens_reached.py | 6 +-- .../sliding_window_conversation_manager.py | 20 +++++----- .../summarizing_conversation_manager.py | 20 +++++----- ...recover_tool_use_on_max_tokens_reached.py} | 38 ++++++++++++------- .../agent/test_conversation_manager.py | 15 +++++--- 8 files changed, 58 insertions(+), 59 deletions(-) rename tests/strands/agent/conversation_manager/{test_token_limit_recovery.py => test_recover_tool_use_on_max_tokens_reached.py} (86%) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 044ff4e67..1f63f7996 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -588,7 +588,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A self.conversation_manager.reduce_context(agent=self, e=e) except MaxTokensReachedException as e: # Recover conversation state after token limit exceeded, then continue with next cycle - self.conversation_manager.handle_token_limit_reached(agent=self, e=e) + await self.conversation_manager.handle_token_limit_reached(agent=self, e=e) # Sync agent after handling exception to keep conversation_manager_state up to date in the session if self._session_manager: diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 57ce59f93..1b42e1fbf 100644 --- 
a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -88,7 +88,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs """ pass - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: """Called when MaxTokensReachedException is thrown to recover conversation state. This method should implement recovery strategies when the token limit is exceeded and the message array diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index fb9868741..5ff6874e5 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -44,17 +44,3 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs raise e else: raise ContextWindowOverflowException("Context window overflowed!") - - # - # def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - # """Does not handle token limit recovery and raises the exception. - # - # Args: - # agent: The agent whose conversation state will remain unmodified. - # e: The MaxTokensReachedException that triggered the recovery. - # **kwargs: Additional keyword arguments for future extensibility. - # - # Raises: - # e: The provided exception. 
- # """ - # raise e diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 516c3ec36..e9e056a69 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: +async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. This function fixes incomplete tool uses that may occur when the model's response is truncated @@ -35,7 +35,7 @@ def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensR if not incomplete_message["content"]: # Cannot correct invalid content block if content is empty - return + raise exception valid_content: list[ContentBlock] = [] for content in incomplete_message["content"]: @@ -59,7 +59,7 @@ def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensR ) else: # ToolUse was invalid for an unknown reason. 
Cannot correct, return without modifying - return + raise exception valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 0559e0efa..58710493d 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -113,6 +113,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs # Overwrite message history messages[:] = messages[trim_index:] + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply sliding window strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + await recover_tool_use_on_max_tokens_reached(agent, e) + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: """Truncate tool results in a message to reduce context size. @@ -178,13 +188,3 @@ def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[i return idx return None - - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply sliding window strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. 
- """ - recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 1dc5d907a..1c3dc7d38 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -167,6 +167,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply summarization strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + await recover_tool_use_on_max_tokens_reached(agent, e) + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. @@ -251,13 +261,3 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point - - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply summarization strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. 
- """ - recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py similarity index 86% rename from tests/strands/agent/conversation_manager/test_token_limit_recovery.py rename to tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 006f5db25..77fc35c39 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -2,6 +2,8 @@ from unittest.mock import Mock +import pytest + from strands.agent.agent import Agent from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( recover_tool_use_on_max_tokens_reached, @@ -11,7 +13,8 @@ from strands.types.exceptions import MaxTokensReachedException -def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): """Test recovery when incomplete tool use is present in the message.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -29,7 +32,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -55,7 +58,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): """Test 
recovery when tool use has no name.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -72,7 +76,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -95,8 +99,9 @@ def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): - """Test that valid tool uses are not modified and function returns early.""" +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): + """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method mock_invoke_callbacks = Mock() @@ -113,7 +118,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + with pytest.raises(MaxTokensReachedException): + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since tool use was valid assert len(agent.messages) == initial_message_count @@ -122,8 +128,9 @@ def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() -def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): - """Test that empty content is handled gracefully.""" +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): + """Test that an exception that is raised without recoverability, re-raises 
exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method mock_invoke_callbacks = Mock() @@ -134,7 +141,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + with pytest.raises(MaxTokensReachedException): + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since content is empty assert len(agent.messages) == initial_message_count @@ -143,7 +151,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): mock_invoke_callbacks.assert_not_called() -def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): """Test recovery with mix of valid content and incomplete tool use.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -162,7 +171,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -189,7 +198,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): """Test that non-tool content is preserved as-is.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -208,7 +218,7 @@ def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): exception = 
MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 3e5bd56f3..83af6c429 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -208,7 +208,8 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages == expected_messages -def test_sliding_window_conversation_manager_handle_token_limit_reached(): +@pytest.mark.asyncio +async def test_sliding_window_conversation_manager_handle_token_limit_reached(): """Test that SlidingWindowConversationManager handles token limit recovery.""" manager = SlidingWindowConversationManager() test_agent = Agent() @@ -224,7 +225,7 @@ def test_sliding_window_conversation_manager_handle_token_limit_reached(): test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 @@ -287,7 +288,8 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): manager.restore_from_session({}) -def test_summarizing_conversation_manager_handle_token_limit_reached(): +@pytest.mark.asyncio +async def test_summarizing_conversation_manager_handle_token_limit_reached(): """Test that SummarizingConversationManager handles token limit recovery.""" from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager @@ -304,7 +306,7 @@ def 
test_summarizing_conversation_manager_handle_token_limit_reached(): test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 @@ -320,7 +322,8 @@ def test_summarizing_conversation_manager_handle_token_limit_reached(): assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] -def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): +@pytest.mark.asyncio +async def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): """Test that NullConversationManager raises the provided exception.""" manager = NullConversationManager() test_agent = Agent() @@ -331,4 +334,4 @@ def test_null_conversation_manager_handle_token_limit_reached_raises_exception() test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message) with pytest.raises(MaxTokensReachedException): - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) From 11b91f417c95d50254e159793d1aff027ceacfbb Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:15:52 -0400 Subject: [PATCH 16/23] feat: add additional error case where no tool uses were fixed --- .../recover_tool_use_on_max_tokens_reached.py | 6 ++++++ .../test_recover_tool_use_on_max_tokens_reached.py | 13 ++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index e9e056a69..8fddd4af5 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ 
b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -38,6 +38,7 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT raise exception valid_content: list[ContentBlock] = [] + has_corrected_content = False for content in incomplete_message["content"]: tool_use: ToolUse | None = content.get("toolUse") if not tool_use: @@ -57,10 +58,15 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT f"to maximum token limits being reached." } ) + has_corrected_content = True else: # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying raise exception + if not has_corrected_content: + # No ToolUse were modified, meaning this method could not have resolved the root cause + raise exception + valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 77fc35c39..8fe576a87 100644 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -128,8 +128,15 @@ async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() +@pytest.mark.parametrize( + "content,description", + [ + ([], "empty content"), + ([{"text": "Just some text with no tools to edit."}], "text-only content"), + ], +) @pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content, description): """Test that an exception that is raised without recoverability, 
re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -137,14 +144,14 @@ async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) - incomplete_message: Message = {"role": "assistant", "content": []} + incomplete_message: Message = {"role": "assistant", "content": content} exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) with pytest.raises(MaxTokensReachedException): await recover_tool_use_on_max_tokens_reached(agent, exception) - # Should not add any message since content is empty + # Should not add any message since there's nothing to recover assert len(agent.messages) == initial_message_count # Verify that the MessageAddedEvent callback was NOT invoked From 1da9ba76c2ef131825d0fac077a7e6b88c88565e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:32:53 -0400 Subject: [PATCH 17/23] feat: add max tokens reached test --- .../recover_tool_use_on_max_tokens_reached.py | 5 +---- .../test_recover_tool_use_on_max_tokens_reached.py | 11 ++--------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 8fddd4af5..35d597e2a 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -32,6 +32,7 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message + logger.warning(f"incomplete message {incomplete_message}") if not incomplete_message["content"]: # Cannot correct 
invalid content block if content is empty @@ -63,10 +64,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying raise exception - if not has_corrected_content: - # No ToolUse were modified, meaning this method could not have resolved the root cause - raise exception - valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 8fe576a87..7d3770699 100644 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -128,15 +128,8 @@ async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() -@pytest.mark.parametrize( - "content,description", - [ - ([], "empty content"), - ([{"text": "Just some text with no tools to edit."}], "text-only content"), - ], -) @pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content, description): +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -144,7 +137,7 @@ async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) - incomplete_message: Message = {"role": "assistant", "content": content} + incomplete_message: Message = {"role": "assistant", "content": []} 
exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) From 623f3c799c9f9fa844b3d86c4a19f086e66a60f3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:34:37 -0400 Subject: [PATCH 18/23] linting --- .../recover_tool_use_on_max_tokens_reached.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 35d597e2a..8c1baa554 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -39,7 +39,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT raise exception valid_content: list[ContentBlock] = [] - has_corrected_content = False for content in incomplete_message["content"]: tool_use: ToolUse | None = content.get("toolUse") if not tool_use: From 66c4c07f6a34ff59cb6d4ca864c63302392f2d53 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:38:18 -0400 Subject: [PATCH 19/23] feat: add max tokens reached test --- .../recover_tool_use_on_max_tokens_reached.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 8c1baa554..e9e056a69 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -32,7 +32,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message - logger.warning(f"incomplete 
message {incomplete_message}") if not incomplete_message["content"]: # Cannot correct invalid content block if content is empty @@ -58,7 +57,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT f"to maximum token limits being reached." } ) - has_corrected_content = True else: # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying raise exception From 4b5c5a72dae6617b66ef21672e98f89805849a3d Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 7 Aug 2025 16:50:00 -0400 Subject: [PATCH 20/23] feat: switch to a default behavior to recover from max tokens reached --- src/strands/agent/agent.py | 21 +- .../agent/conversation_manager/__init__.py | 2 - .../conversation_manager.py | 15 - .../recover_tool_use_on_max_tokens_reached.py | 66 ----- .../sliding_window_conversation_manager.py | 13 +- .../summarizing_conversation_manager.py | 13 +- .../_recover_message_on_max_tokens_reached.py | 76 +++++ src/strands/event_loop/event_loop.py | 32 ++- src/strands/types/exceptions.py | 6 +- .../agent/conversation_manager/__init__.py | 1 - ..._recover_tool_use_on_max_tokens_reached.py | 244 ---------------- tests/strands/agent/test_agent.py | 60 +--- .../agent/test_conversation_manager.py | 91 +----- tests/strands/event_loop/test_event_loop.py | 55 ++-- ...t_recover_message_on_max_tokens_reached.py | 267 ++++++++++++++++++ tests_integ/test_max_tokens_reached.py | 24 +- 16 files changed, 417 insertions(+), 569 deletions(-) delete mode 100644 src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py create mode 100644 src/strands/event_loop/_recover_message_on_max_tokens_reached.py delete mode 100644 tests/strands/agent/conversation_manager/__init__.py delete mode 100644 tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py create mode 100644 tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py diff --git a/src/strands/agent/agent.py 
b/src/strands/agent/agent.py index 1f63f7996..111509e3a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -582,21 +582,18 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A ) async for event in events: yield event - return + except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(agent=self, e=e) - except MaxTokensReachedException as e: - # Recover conversation state after token limit exceeded, then continue with next cycle - await self.conversation_manager.handle_token_limit_reached(agent=self, e=e) + self.conversation_manager.reduce_context(self, e=e) - # Sync agent after handling exception to keep conversation_manager_state up to date in the session - if self._session_manager: - self._session_manager.sync_agent(self) + # Sync agent after reduce_context to keep conversation_manager_state up to date in the session + if self._session_manager: + self._session_manager.sync_agent(self) - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: - yield event + events = self._execute_event_loop_cycle(invocation_state) + async for event in events: + yield event def _record_tool_execution( self, diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 7e7e0c6c5..c59623215 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -15,14 +15,12 @@ from 
.conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager -from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager __all__ = [ "ConversationManager", "NullConversationManager", - "recover_tool_use_on_max_tokens_reached", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 1b42e1fbf..2c1ee7847 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, Optional from ...types.content import Message -from ...types.exceptions import MaxTokensReachedException if TYPE_CHECKING: from ...agent.agent import Agent @@ -87,17 +86,3 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs **kwargs: Additional keyword arguments for future extensibility. """ pass - - async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Called when MaxTokensReachedException is thrown to recover conversation state. - - This method should implement recovery strategies when the token limit is exceeded and the message array - may be in a broken state. - - Args: - agent: The agent whose conversation state will be recovered. - This list is modified in-place. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. 
- """ - raise e diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py deleted file mode 100644 index e9e056a69..000000000 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Shared utility for handling token limit recovery in conversation managers.""" - -import logging -from typing import TYPE_CHECKING - -from ...hooks import MessageAddedEvent -from ...types.content import ContentBlock, Message -from ...types.exceptions import MaxTokensReachedException -from ...types.tools import ToolUse - -if TYPE_CHECKING: - from ...agent.agent import Agent - -logger = logging.getLogger(__name__) - - -async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: - """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. - - This function fixes incomplete tool uses that may occur when the model's response is truncated - due to token limits. It: - - 1. Inspects each content block in the incomplete message for invalid tool uses - 2. Replaces incomplete tool use blocks with informative text messages - 3. Preserves valid content blocks in the corrected message - 4. Adds the corrected message to the agent's conversation history - - Args: - agent: The agent whose conversation will be updated with the corrected message. - exception: The MaxTokensReachedException containing the incomplete message. 
- """ - logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") - - incomplete_message: Message = exception.incomplete_message - - if not incomplete_message["content"]: - # Cannot correct invalid content block if content is empty - raise exception - - valid_content: list[ContentBlock] = [] - for content in incomplete_message["content"]: - tool_use: ToolUse | None = content.get("toolUse") - if not tool_use: - valid_content.append(content) - continue - - # Check if tool use is incomplete (missing or empty required fields) - tool_name = tool_use.get("name") - if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): - # Tool use is incomplete due to max_tokens truncation - display_name = tool_name if tool_name else "" - logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) - - valid_content.append( - { - "text": f"The selected tool {display_name}'s tool use was incomplete due " - f"to maximum token limits being reached." - } - ) - else: - # ToolUse was invalid for an unknown reason. 
Cannot correct, return without modifying - raise exception - - valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} - agent.messages.append(valid_message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 58710493d..e082abe8e 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -7,9 +7,8 @@ from ...agent.agent import Agent from ...types.content import Messages -from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager -from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached logger = logging.getLogger(__name__) @@ -113,16 +112,6 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs # Overwrite message history messages[:] = messages[trim_index:] - async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply sliding window strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - """ - await recover_tool_use_on_max_tokens_reached(agent, e) - def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: """Truncate tool results in a message to reduce context size. 
diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 1c3dc7d38..60e832215 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -6,9 +6,8 @@ from typing_extensions import override from ...types.content import Message -from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager -from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -167,16 +166,6 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e - async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply summarization strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - """ - await recover_tool_use_on_max_tokens_reached(agent, e) - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..e4b208fdb --- /dev/null +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -0,0 +1,76 @@ +"""Message recovery utilities for handling max token limit scenarios. 
+ +This module provides functionality to recover and clean up incomplete messages that occur +when model responses are truncated due to maximum token limits being reached. It specifically +handles cases where tool use blocks are incomplete or malformed due to truncation. +""" + +import logging + +from ..types.content import ContentBlock, Message +from ..types.tools import ToolUse + +logger = logging.getLogger(__name__) + + +def recover_message_on_max_tokens_reached(message: Message) -> Message: + """Recover and clean up incomplete messages when max token limits are reached. + + When a model response is truncated due to maximum token limits, tool use blocks may be + incomplete or malformed. This function inspects the message content and: + + 1. Identifies incomplete tool use blocks (missing name, input, or toolUseId) + 2. Replaces incomplete tool uses with informative error messages + 3. Preserves all valid content blocks (text and complete tool uses) + 4. Returns a cleaned message suitable for conversation history + + This recovery mechanism ensures that the conversation can continue gracefully even when + model responses are truncated, providing clear feedback about what happened. + + Args: + message: The potentially incomplete message from the model that was truncated + due to max token limits. + + Returns: + A cleaned Message with incomplete tool uses replaced by explanatory text content. + The returned message maintains the same role as the input message. 
+ + Example: + If a message contains an incomplete tool use like: + ``` + {"toolUse": {"name": "calculator"}} # missing input and toolUseId + ``` + + It will be replaced with: + ``` + {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."} + ``` + """ + logger.info("handling max_tokens stop reason - inspecting incomplete message for invalid tool uses") + + valid_content: list[ContentBlock] = [] + for content in message["content"] or []: + tool_use: ToolUse | None = content.get("toolUse") + if not tool_use: + valid_content.append(content) + continue + + # Check if tool use is incomplete (missing or empty required fields) + tool_name = tool_use.get("name") + if tool_name and tool_use.get("input") and tool_use.get("toolUseId"): + # As far as we can tell, tool use is valid if this condition is true + valid_content.append(content) + continue + + # Tool use is incomplete due to max_tokens truncation + display_name = tool_name if tool_name else "" + logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) + + valid_content.append( + { + "text": f"The selected tool {display_name}'s tool use was incomplete due " + f"to maximum token limits being reached." 
+ } + ) + + return {"content": valid_content, "role": message["role"]} diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ae21d4c6d..b36f73155 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -36,6 +36,7 @@ ) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages if TYPE_CHECKING: @@ -156,6 +157,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) ) + if stop_reason == "max_tokens": + message = recover_message_on_max_tokens_reached(message) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) break # Success! Break out of retry loop @@ -192,6 +196,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + # Add message in trace and mark the end of the stream messages trace + stream_trace.add_message(message) + stream_trace.end() + + # Add the response message to the conversation + agent.messages.append(message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + yield {"callback": {"message": message}} + + # Update metrics + agent.event_loop_metrics.update_usage(usage) + agent.event_loop_metrics.update_metrics(metrics) + if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. @@ -205,21 +222,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ), - incomplete_message=message, + ) ) - # Add message in trace and mark the end of the stream messages trace - stream_trace.add_message(message) - stream_trace.end() - - # Add the response message to the conversation - agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} - - # Update metrics - agent.event_loop_metrics.update_usage(usage) - agent.event_loop_metrics.update_metrics(metrics) # If the model is requesting to use tools if stop_reason == "tool_use": diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 71ea28b9f..90f2b8d7f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,8 +2,6 @@ from typing import Any -from strands.types.content import Message - class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -28,14 +26,12 @@ class MaxTokensReachedException(Exception): the complexity of the response, or when the model naturally reaches its configured output limit during generation. """ - def __init__(self, message: str, incomplete_message: Message): + def __init__(self, message: str): """Initialize the exception with an error message and the incomplete message object. 
Args: message: The error message describing the token limit issue - incomplete_message: The valid Message object with incomplete content due to token limits """ - self.incomplete_message = incomplete_message super().__init__(message) diff --git a/tests/strands/agent/conversation_manager/__init__.py b/tests/strands/agent/conversation_manager/__init__.py deleted file mode 100644 index d5ee2d119..000000000 --- a/tests/strands/agent/conversation_manager/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Test package for conversation manager diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py deleted file mode 100644 index 7d3770699..000000000 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Tests for token limit recovery utility.""" - -from unittest.mock import Mock - -import pytest - -from strands.agent.agent import Agent -from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( - recover_tool_use_on_max_tokens_reached, -) -from strands.hooks import MessageAddedEvent -from strands.types.content import Message -from strands.types.exceptions import MaxTokensReachedException - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): - """Test recovery when incomplete tool use is present in the message.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", 
incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected message - assert len(agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 2 - - # First content block should be preserved - assert corrected_message["content"][0] == {"text": "I'll help you with that."} - - # Second content block should be replaced with error message - assert "text" in corrected_message["content"][1] - assert "calculator" in corrected_message["content"][1]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] - - # Verify that the MessageAddedEvent callback was invoked - mock_invoke_callbacks.assert_called_once() - call_args = mock_invoke_callbacks.call_args[0][0] - assert isinstance(call_args, MessageAddedEvent) - assert call_args.agent == agent - assert call_args.message == corrected_message - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): - """Test recovery when tool use has no name.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected message - assert len(agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = agent.messages[-1] - assert corrected_message["role"] == "assistant" - 
assert len(corrected_message["content"]) == 1 - - # Content should be replaced with error message using - assert "text" in corrected_message["content"][0] - assert "" in corrected_message["content"][0]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] - - # Verify that the MessageAddedEvent callback was invoked - mock_invoke_callbacks.assert_called_once() - call_args = mock_invoke_callbacks.call_args[0][0] - assert isinstance(call_args, MessageAddedEvent) - assert call_args.agent == agent - assert call_args.message == corrected_message - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): - """Test that an exception that is raised without recoverability, re-raises exception.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - with pytest.raises(MaxTokensReachedException): - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should not add any message since tool use was valid - assert len(agent.messages) == initial_message_count - - # Verify that the MessageAddedEvent callback was NOT invoked - mock_invoke_callbacks.assert_not_called() - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): - """Test that an exception that is raised without recoverability, re-raises exception.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - 
initial_message_count = len(agent.messages) - - incomplete_message: Message = {"role": "assistant", "content": []} - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - with pytest.raises(MaxTokensReachedException): - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should not add any message since there's nothing to recover - assert len(agent.messages) == initial_message_count - - # Verify that the MessageAddedEvent callback was NOT invoked - mock_invoke_callbacks.assert_not_called() - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): - """Test recovery with mix of valid content and incomplete tool use.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "Let me calculate this for you."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete - {"text": "And then I'll explain the result."}, - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected message - assert len(agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 3 - - # First and third content blocks should be preserved - assert corrected_message["content"][0] == {"text": "Let me calculate this for you."} - assert corrected_message["content"][2] == {"text": "And then I'll explain the result."} - - # Second content block should be replaced with error message - assert "text" in 
corrected_message["content"][1] - assert "calculator" in corrected_message["content"][1]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] - - # Verify that the MessageAddedEvent callback was invoked - mock_invoke_callbacks.assert_called_once() - call_args = mock_invoke_callbacks.call_args[0][0] - assert isinstance(call_args, MessageAddedEvent) - assert call_args.agent == agent - assert call_args.message == corrected_message - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): - """Test that non-tool content is preserved as-is.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "Here's some text."}, - {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, - {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected message - assert len(agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 3 - - # First two content blocks should be preserved exactly - assert corrected_message["content"][0] == {"text": "Here's some text."} - assert corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} - - # Third content block should be replaced with error message - assert "text" in corrected_message["content"][2] - assert "" in corrected_message["content"][2]["text"] - - # Verify that the 
MessageAddedEvent callback was invoked - mock_invoke_callbacks.assert_called_once() - call_args = mock_invoke_callbacks.call_args[0][0] - assert isinstance(call_args, MessageAddedEvent) - assert call_args.agent == agent - assert call_args.message == corrected_message diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1bc5ad78a..4e310dace 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,7 +19,7 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException +from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -547,64 +547,6 @@ def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): agent("Test!") -def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(mock_model, agent, agenerator): - """Test that MaxTokensReachedException triggers conversation manager handle_token_limit_reached.""" - conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) - agent.conversation_manager = conversation_manager_spy - - incomplete_message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - mock_model.mock_stream.side_effect = [ - # First occurrence - MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message), - # On retry the loop should succeed - agenerator( 
- [ - {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"text": "Recovered response"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn"}}, - ] - ), - ] - - agent("Test message") - - # Verify handle_token_limit_reached was called - assert conversation_manager_spy.handle_token_limit_reached.call_count == 1 - - # Verify the call was made with the correct exception - call_args = conversation_manager_spy.handle_token_limit_reached.call_args - kwargs = list(call_args[1].values()) - assert isinstance(kwargs[0], Agent) - assert isinstance(kwargs[1], MaxTokensReachedException) - - -def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent): - """Test that MaxTokensReachedException with NullConversationManager raises the exception.""" - agent.conversation_manager = NullConversationManager() - - incomplete_message = { - "role": "assistant", - "content": [ - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - mock_model.mock_stream.side_effect = MaxTokensReachedException( - message="Token limit reached", incomplete_message=incomplete_message - ) - - with pytest.raises(MaxTokensReachedException): - agent("Test!") - - def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agenerator): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 83af6c429..77d7dcce8 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -3,11 +3,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import 
SlidingWindowConversationManager -from strands.types.content import Message -from strands.types.exceptions import ( - ContextWindowOverflowException, - MaxTokensReachedException, -) +from strands.types.exceptions import ContextWindowOverflowException @pytest.fixture @@ -208,42 +204,6 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages == expected_messages -@pytest.mark.asyncio -async def test_sliding_window_conversation_manager_handle_token_limit_reached(): - """Test that SlidingWindowConversationManager handles token limit recovery.""" - manager = SlidingWindowConversationManager() - test_agent = Agent() - initial_message_count = len(test_agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await manager.handle_token_limit_reached(test_agent, test_exception) - - # Should add one corrected message - assert len(test_agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = test_agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 2 - - # First content block should be preserved - assert corrected_message["content"][0] == {"text": "I'll help you with that."} - - # Second content block should be replaced with error message - assert "text" in corrected_message["content"][1] - assert "calculator" in corrected_message["content"][1]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] - - def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = NullConversationManager() @@ 
-286,52 +246,3 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): with pytest.raises(ValueError): manager.restore_from_session({}) - - -@pytest.mark.asyncio -async def test_summarizing_conversation_manager_handle_token_limit_reached(): - """Test that SummarizingConversationManager handles token limit recovery.""" - from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager - - manager = SummarizingConversationManager() - test_agent = Agent() - initial_message_count = len(test_agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name - ], - } - - test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await manager.handle_token_limit_reached(test_agent, test_exception) - - # Should add one corrected message - assert len(test_agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = test_agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 1 - - # Content should be replaced with error message using - assert "text" in corrected_message["content"][0] - assert "" in corrected_message["content"][0]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] - - -@pytest.mark.asyncio -async def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): - """Test that NullConversationManager raises the provided exception.""" - manager = NullConversationManager() - test_agent = Agent() - test_message: Message = { - "role": "assistant", - "content": [{"text": "Hello"}], - } - test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message) - - with pytest.raises(MaxTokensReachedException): - await 
manager.handle_token_limit_reached(test_agent, test_exception) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 3886df8b9..191ab51ba 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -305,8 +305,10 @@ async def test_event_loop_cycle_text_response_error( await alist(stream) +@patch("strands.event_loop.event_loop.recover_message_on_max_tokens_reached") @pytest.mark.asyncio async def test_event_loop_cycle_tool_result( + mock_recover_message, agent, model, system_prompt, @@ -339,6 +341,9 @@ async def test_event_loop_cycle_tool_result( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason + mock_recover_message.assert_not_called() + model.stream.assert_called_with( [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -568,25 +573,35 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises MaxTokensReachedException.""" + """Test that max_tokens stop reason calls _recover_message_on_max_tokens_reached then MaxTokensReachedException.""" - # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 - model.stream.return_value = agenerator( - [ - { - "contentBlockStart": { - "start": { - "toolUse": {}, + model.stream.side_effect = [ + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "asdf", + "input": {}, # empty + }, + }, }, }, - }, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "max_tokens"}}, - ] - ) + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ), + ] # Call event_loop_cycle, expecting it to raise MaxTokensReachedException - with pytest.raises(MaxTokensReachedException) as exc_info: 
+ expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + with pytest.raises(MaxTokensReachedException, match=expected_message): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -594,16 +609,8 @@ async def test_event_loop_cycle_max_tokens_exception( await alist(stream) # Verify the exception message contains the expected content - expected_message = ( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ) - assert str(exc_info.value) == expected_message - - # Verify that the message has not been appended to the messages array - assert len(agent.messages) == 1 - assert exc_info.value.incomplete_message not in agent.messages + assert len(agent.messages) == 2 + assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"] @patch("strands.event_loop.event_loop.get_tracer") diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..e751be161 --- /dev/null +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -0,0 +1,267 @@ +"""Tests for token limit recovery utility.""" + +from strands.event_loop._recover_message_on_max_tokens_reached import ( + recover_message_on_max_tokens_reached, +) +from strands.types.content import Message + + +def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use(): + """Test recovery when incomplete tool use is present in the message.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", 
"input": {}, "toolUseId": ""}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + + # First content block should be preserved + assert result["content"][0] == {"text": "I'll help you with that."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_tool_name(): + """Test recovery when tool use has no name.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message using + assert "text" in result["content"][0] + assert "" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_input(): + """Test recovery when tool use has no input.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "toolUseId": "123"}}, # Missing input + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in 
result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id(): + """Test recovery when tool use has no toolUseId.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_valid_tool_use(): + """Test that valid tool uses are preserved unchanged.""" + complete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid + ], + } + + result = recover_message_on_max_tokens_reached(complete_message) + + # Should preserve the message exactly as-is + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + assert result["content"][0] == {"text": "I'll help you with that."} + assert result["content"][1] == { + "toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"} + } + + +def test_recover_message_on_max_tokens_reached_with_empty_content(): + """Test handling of message with empty content.""" + empty_message: Message = {"role": "assistant", "content": []} + + result = recover_message_on_max_tokens_reached(empty_message) + + # Should return message with empty content preserved + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_none_content(): + """Test handling of message with None 
content.""" + none_content_message: Message = {"role": "assistant", "content": None} + + result = recover_message_on_max_tokens_reached(none_content_message) + + # Should return message with empty content + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_mixed_content(): + """Test recovery with mix of valid content and incomplete tool use.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Let me calculate this for you."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete + {"text": "And then I'll explain the result."}, + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First and third content blocks should be preserved + assert result["content"][0] == {"text": "Let me calculate this for you."} + assert result["content"][2] == {"text": "And then I'll explain the result."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_preserves_non_tool_content(): + """Test that non-tool content is preserved as-is.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Here's some text."}, + {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved exactly + assert result["content"][0] 
== {"text": "Here's some text."} + assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "<unknown>" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] + + +def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): + """Test recovery with multiple incomplete tool uses.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId + {"text": "Some text in between."}, + {"toolUse": {"name": "", "input": {}, "toolUseId": "456"}}, # Missing name + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First tool use should be replaced + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + # Text content should be preserved + assert result["content"][1] == {"text": "Some text in between."} + + # Second tool use should be replaced with <unknown> + assert "text" in result["content"][2] + assert "<unknown>" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] + + +def test_recover_message_on_max_tokens_reached_preserves_user_role(): + """Test that the function preserves the original message role.""" + incomplete_message: Message = { + "role": "user", + "content": [ + {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Should preserve the original role + assert result["role"] == "user" + assert len(result["content"]) == 1 + assert "text" in
result["content"][0] + assert "calculator" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_content_without_tool_use(): + """Test handling of content blocks that don't have toolUse key.""" + message: Message = { + "role": "assistant", + "content": [ + {"text": "Regular text content."}, + {"someOtherKey": "someValue"}, # Content without toolUse + {"toolUse": {"name": "calculator"}}, # Incomplete tool use + ], + } + + result = recover_message_on_max_tokens_reached(message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved + assert result["content"][0] == {"text": "Regular text content."} + assert result["content"][1] == {"someOtherKey": "someValue"} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "calculator" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index d50452801..bf5668349 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -2,8 +2,8 @@ import pytest +from src.strands.agent import AgentResult from strands import Agent, tool -from strands.agent import NullConversationManager from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException @@ -19,23 +19,14 @@ def story_tool(story: str) -> str: def test_max_tokens_reached(): + """Test that MaxTokensReachedException is raised but the agent can still rerun on the second pass""" model = BedrockModel(max_tokens=100) - agent = Agent(model=model, tools=[story_tool], conversation_manager=NullConversationManager()) + agent = Agent(model=model, tools=[story_tool]) + # This should raise an exception with pytest.raises(MaxTokensReachedException): 
agent("Tell me a story!") - assert len(agent.messages) == 1 - - -def test_max_tokens_reached_with_hook_provider(): - """Test that MaxTokensReachedException can be handled by a hook provider.""" - model = BedrockModel(max_tokens=100) - agent = Agent(model=model, tools=[story_tool]) # Defaults to include SlidingWindowConversationManager - - # This should NOT raise an exception because the hook handles it - agent("Tell me a story!") - # Validate that at least one message contains the incomplete tool use error message expected_text = "tool use was incomplete due to maximum token limits being reached" all_text_content = [ @@ -48,3 +39,10 @@ def test_max_tokens_reached_with_hook_provider(): assert any(expected_text in text for text in all_text_content), ( f"Expected to find message containing '{expected_text}' in agent messages" ) + + # Remove tools from agent and re-run with a generic question + agent.tool_registry.registry = {} + agent.tool_registry.tool_config = {} + + result: AgentResult = agent("What is 3+3") + assert result.stop_reason == "end_turn" From 83ad822de0e777a011ffc927286dd748f1e4cc69 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 09:59:41 -0400 Subject: [PATCH 21/23] fix: all tool uses now must be replaced --- .../_recover_message_on_max_tokens_reached.py | 38 +++++++++---------- ...t_recover_message_on_max_tokens_reached.py | 12 +++--- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index e4b208fdb..4282f319d 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -14,31 +14,36 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: - """Recover and clean up incomplete messages when max token limits are reached. 
+ """Recover and clean up messages when max token limits are reached. - When a model response is truncated due to maximum token limits, tool use blocks may be - incomplete or malformed. This function inspects the message content and: + When a model response is truncated due to maximum token limits, all tool use blocks + should be replaced with informative error messages since they may be incomplete or + unreliable. This function inspects the message content and: - 1. Identifies incomplete tool use blocks (missing name, input, or toolUseId) - 2. Replaces incomplete tool uses with informative error messages - 3. Preserves all valid content blocks (text and complete tool uses) + 1. Identifies all tool use blocks (regardless of validity) + 2. Replaces all tool uses with informative error messages + 3. Preserves all non-tool content blocks (text, images, etc.) 4. Returns a cleaned message suitable for conversation history This recovery mechanism ensures that the conversation can continue gracefully even when - model responses are truncated, providing clear feedback about what happened. + model responses are truncated, providing clear feedback about what happened and preventing + potentially incomplete or corrupted tool executions. + + TODO: after https://github.com/strands-agents/sdk-python/issues/561 is completed, only the verifiable + invalid tool_use content blocks need to be replaced. Args: message: The potentially incomplete message from the model that was truncated due to max token limits. Returns: - A cleaned Message with incomplete tool uses replaced by explanatory text content. + A cleaned Message with all tool uses replaced by explanatory text content. The returned message maintains the same role as the input message. 
Example: - If a message contains an incomplete tool use like: + If a message contains any tool use (complete or incomplete): ``` - {"toolUse": {"name": "calculator"}} # missing input and toolUseId + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}} ``` It will be replaced with: @@ -46,7 +51,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."} ``` """ - logger.info("handling max_tokens stop reason - inspecting incomplete message for invalid tool uses") + logger.info("handling max_tokens stop reason - replacing all tool uses with error messages") valid_content: list[ContentBlock] = [] for content in message["content"] or []: @@ -55,15 +60,8 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: valid_content.append(content) continue - # Check if tool use is incomplete (missing or empty required fields) - tool_name = tool_use.get("name") - if tool_name and tool_use.get("input") and tool_use.get("toolUseId"): - # As far as we can tell, tool use is valid if this condition is true - valid_content.append(content) - continue - - # Tool use is incomplete due to max_tokens truncation - display_name = tool_name if tool_name else "" + # Replace all tool uses with error messages when max_tokens is reached + display_name = tool_use.get("name", "") logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) valid_content.append( diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py index e751be161..402e90966 100644 --- a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -95,7 +95,7 @@ def 
test_recover_message_on_max_tokens_reached_with_missing_tool_use_id(): def test_recover_message_on_max_tokens_reached_with_valid_tool_use(): - """Test that valid tool uses are preserved unchanged.""" + """Test that even valid tool uses are replaced with error messages.""" complete_message: Message = { "role": "assistant", "content": [ @@ -106,13 +106,15 @@ def test_recover_message_on_max_tokens_reached_with_valid_tool_use(): result = recover_message_on_max_tokens_reached(complete_message) - # Should preserve the message exactly as-is + # Should replace even valid tool uses with error messages assert result["role"] == "assistant" assert len(result["content"]) == 2 assert result["content"][0] == {"text": "I'll help you with that."} - assert result["content"][1] == { - "toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"} - } + + # Valid tool use should also be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] def test_recover_message_on_max_tokens_reached_with_empty_content(): From faa4618197a33c7673fb9d66844f66bd795c9a5f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 10:03:18 -0400 Subject: [PATCH 22/23] fix: boolean --- .../event_loop/_recover_message_on_max_tokens_reached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index 4282f319d..74077042c 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -61,7 +61,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: continue # Replace all tool uses with error messages when max_tokens is reached - display_name = tool_use.get("name", "") + 
display_name = tool_use.get("name") or "<unknown>" + logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) + valid_content.append( From fa8195f186ff721f7044dac8db02517333bd17cf Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 10:35:13 -0400 Subject: [PATCH 23/23] remove todo --- .../event_loop/_recover_message_on_max_tokens_reached.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index 74077042c..ab6fb4abe 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -29,9 +29,6 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: model responses are truncated, providing clear feedback about what happened and preventing potentially incomplete or corrupted tool executions. - TODO: after https://github.com/strands-agents/sdk-python/issues/561 is completed, only the verifiable - invalid tool_use content blocks need to be replaced. - Args: message: The potentially incomplete message from the model that was truncated due to max token limits.