Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/strands/_exception_notes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Exception note utilities for Python 3.10+ compatibility."""

# add_note was added in 3.11 - we hoist to a constant to facilitate testing
supports_add_note = hasattr(Exception, "add_note")


def add_exception_note(exception: Exception, note: str) -> None:
"""Add a note to an exception, compatible with Python 3.10+.
Uses add_note() if it's available (Python 3.11+) or modifies the exception message if it is not.
"""
if supports_add_note:
# we ignore the mypy error because the version-check for add_note is extracted into a constant up above and
# mypy doesn't detect that
exception.add_note(note) # type: ignore
else:
# For Python 3.10, append note to the exception message
if hasattr(exception, "args") and exception.args:
exception.args = (f"{exception.args[0]}\n{note}",) + exception.args[1:]
else:
exception.args = (note,)
47 changes: 24 additions & 23 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

from .._exception_notes import add_exception_note
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
Expand Down Expand Up @@ -716,29 +717,29 @@ def _stream(

region = self.client.meta.region_name

# add_note added in Python 3.11
if hasattr(e, "add_note"):
# Aid in debugging by adding more information
e.add_note(f"└ Bedrock region: {region}")
e.add_note(f"└ Model id: {self.config.get('model_id')}")

if (
e.response["Error"]["Code"] == "AccessDeniedException"
and "You don't have access to the model" in error_message
):
e.add_note(
"└ For more information see "
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue"
)

if (
e.response["Error"]["Code"] == "ValidationException"
and "with on-demand throughput isn’t supported" in error_message
):
e.add_note(
"└ For more information see "
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
)
# Aid in debugging by adding more information
add_exception_note(e, f"└ Bedrock region: {region}")
add_exception_note(e, f"└ Model id: {self.config.get('model_id')}")

if (
e.response["Error"]["Code"] == "AccessDeniedException"
and "You don't have access to the model" in error_message
):
add_exception_note(
e,
"└ For more information see "
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue",
)

if (
e.response["Error"]["Code"] == "ValidationException"
and "with on-demand throughput isn’t supported" in error_message
):
add_exception_note(
e,
"└ For more information see "
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported",
)

raise e

Expand Down
19 changes: 19 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import traceback
import unittest.mock
from unittest.mock import ANY

Expand All @@ -10,6 +11,7 @@
from botocore.exceptions import ClientError, EventStreamError

import strands
from strands import _exception_notes
from strands.models import BedrockModel
from strands.models.bedrock import (
_DEFAULT_BEDROCK_MODEL_ID,
Expand Down Expand Up @@ -1209,6 +1211,23 @@ async def test_add_note_on_client_error(bedrock_client, model, alist, messages):
assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"]


@pytest.mark.asyncio
async def test_add_note_on_client_error_without_add_notes(bedrock_client, model, alist, messages):
"""Test that when add_note is not used, the region & model are still included in the error output."""
with unittest.mock.patch.object(_exception_notes, "supports_add_note", False):
# Mock the client error response
error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}}
bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConversationStream")

# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
await alist(model.stream(messages))

error_str = "".join(traceback.format_exception(err.value))
assert "└ Bedrock region: us-west-2" in error_str
assert "└ Model id: m1" in error_str


@pytest.mark.asyncio
async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages):
"""Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception)."""
Expand Down
51 changes: 51 additions & 0 deletions tests/strands/test_exception_notes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for exception note utilities."""

import sys
import traceback
import unittest.mock

import pytest

from strands import _exception_notes
from strands._exception_notes import add_exception_note


@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
def test_add_exception_note_python_311_plus():
"""Test add_exception_note uses add_note in Python 3.11+."""
exception = ValueError("original message")

add_exception_note(exception, "test note")

assert traceback.format_exception(exception) == ["ValueError: original message\n", "test note\n"]


def test_add_exception_note_python_310():
"""Test add_exception_note modifies args in Python 3.10."""
with unittest.mock.patch.object(_exception_notes, "supports_add_note", False):
exception = ValueError("original message")

add_exception_note(exception, "test note")

assert traceback.format_exception(exception) == ["ValueError: original message\ntest note\n"]


def test_add_exception_note_python_310_no_args():
"""Test add_exception_note handles exception with no args in Python 3.10."""
with unittest.mock.patch.object(_exception_notes, "supports_add_note", False):
exception = ValueError()
exception.args = ()

add_exception_note(exception, "test note")

assert traceback.format_exception(exception) == ["ValueError: test note\n"]


def test_add_exception_note_python_310_multiple_args():
"""Test add_exception_note preserves additional args in Python 3.10."""
with unittest.mock.patch.object(_exception_notes, "supports_add_note", False):
exception = ValueError("original message", "second arg")

add_exception_note(exception, "test note")

assert traceback.format_exception(exception) == ["ValueError: ('original message\\ntest note', 'second arg')\n"]