diff --git a/src/strands_tools/code_interpreter/__init__.py b/src/strands_tools/code_interpreter/__init__.py index 67313174..0591edde 100644 --- a/src/strands_tools/code_interpreter/__init__.py +++ b/src/strands_tools/code_interpreter/__init__.py @@ -17,10 +17,10 @@ Example: >>> from strands_tools.code_interpreter import AgentCoreCodeInterpreter - >>> + >>> >>> # Default usage >>> interpreter = AgentCoreCodeInterpreter(region="us-west-2") - >>> + >>> >>> # Custom identifier usage >>> custom_interpreter = AgentCoreCodeInterpreter( ... region="us-west-2", diff --git a/src/strands_tools/code_interpreter/agent_core_code_interpreter.py b/src/strands_tools/code_interpreter/agent_core_code_interpreter.py index a49aaad0..19d9cbe2 100644 --- a/src/strands_tools/code_interpreter/agent_core_code_interpreter.py +++ b/src/strands_tools/code_interpreter/agent_core_code_interpreter.py @@ -24,11 +24,11 @@ class SessionInfo: """ Information about a code interpreter session. - + This dataclass stores the essential information for managing active code interpreter sessions, including the session identifier, description, and the underlying Bedrock client instance. - + Attributes: session_id (str): Unique identifier for the session assigned by AWS Bedrock. description (str): Human-readable description of the session purpose. @@ -44,43 +44,43 @@ class SessionInfo: class AgentCoreCodeInterpreter(CodeInterpreter): """ Bedrock AgentCore implementation of the CodeInterpreter. - + This class provides a code interpreter interface using AWS Bedrock AgentCore services. It supports executing Python, JavaScript, and TypeScript code in isolated sandbox environments with custom code interpreter identifiers. - + The class maintains session state and provides methods for code execution, file operations, and session management. It supports both default AWS code interpreter environments and custom environments specified by identifier. - + Examples: Basic usage with default identifier: - + >>> interpreter = AgentCoreCodeInterpreter(region="us-west-2") >>> # Uses default identifier: "aws.codeinterpreter.v1" - + Using a custom code interpreter identifier: - + >>> custom_id = "my-custom-interpreter-abc123" >>> interpreter = AgentCoreCodeInterpreter( - ... region="us-west-2", + ... region="us-west-2", ... identifier=custom_id ... ) - + Environment-specific usage: - + >>> # For testing environments >>> test_interpreter = AgentCoreCodeInterpreter( ... region="us-east-1", ... identifier="test-interpreter-xyz789" ... ) - - >>> # For production environments + + >>> # For production environments >>> prod_interpreter = AgentCoreCodeInterpreter( ... region="us-west-2", ... identifier="prod-interpreter-def456" ... ) - + Attributes: region (str): The AWS region where the code interpreter service is hosted. identifier (str): The code interpreter identifier being used for sessions. @@ -97,24 +97,24 @@ def __init__(self, region: Optional[str] = None, identifier: Optional[str] = Non identifier (Optional[str]): Custom code interpreter identifier to use for code execution sessions. This allows you to specify custom code interpreter environments instead of the default AWS-provided one. - + Valid formats include: - Default identifier: "aws.codeinterpreter.v1" (used when None) - Custom identifier: "my-custom-interpreter-abc123" - Environment-specific: "test-interpreter-xyz789" - + Note: Use the code interpreter ID, not the full ARN. The AWS service expects the identifier portion only (e.g., "my-interpreter-123" rather than "arn:aws:bedrock-agentcore:region:account:code-interpreter-custom/my-interpreter-123"). - + If not provided, defaults to "aws.codeinterpreter.v1" for backward compatibility. Defaults to None. - + Note: This constructor maintains full backward compatibility. Existing code that doesn't specify the identifier parameter will continue to work unchanged with the default AWS code interpreter environment. - + Raises: Exception: If there are issues with AWS region resolution or client initialization during session creation. @@ -152,20 +152,20 @@ def cleanup_platform(self) -> None: def init_session(self, action: InitSessionAction) -> Dict[str, Any]: """ Initialize a new Bedrock AgentCore sandbox session. - + Creates a new code interpreter session using the configured identifier. The session will use the identifier specified during class initialization, or the default "aws.codeinterpreter.v1" if none was provided. - + Args: action (InitSessionAction): Action containing session initialization parameters including session_name and description. - + Returns: Dict[str, Any]: Response dictionary containing session information on success or error details on failure. Success response includes sessionName, description, and sessionId. - + Raises: Exception: If session initialization fails due to AWS service issues, invalid identifier, or other configuration problems. @@ -219,8 +219,8 @@ def init_session(self, action: InitSessionAction) -> Dict[str, Any]: f"Failed to initialize session '{session_name}' with identifier: {self.identifier}. Error: {str(e)}" ) return { - "status": "error", - "content": [{"text": f"Failed to initialize session '{session_name}': {str(e)}"}] + "status": "error", + "content": [{"text": f"Failed to initialize session '{session_name}': {str(e)}"}], } def list_local_sessions(self) -> Dict[str, Any]: diff --git a/src/strands_tools/retrieve.py b/src/strands_tools/retrieve.py index 9205dba0..ea558b24 100644 --- a/src/strands_tools/retrieve.py +++ b/src/strands_tools/retrieve.py @@ -187,14 +187,21 @@ def format_results_for_display(results: List[Dict[str, Any]]) -> str: results: List of retrieval results from Bedrock Knowledge Base Returns: - Formatted string containing the results in a readable format + Formatted string containing the results in a readable format, including score, document ID, and content. """ if not results: return "No results found above score threshold." formatted = [] for result in results: - doc_id = result.get("location", {}).get("customDocumentLocation", {}).get("id", "Unknown") + # Extract document location - handle both s3Location and customDocumentLocation + location = result.get("location", {}) + doc_id = "Unknown" + if "customDocumentLocation" in location: + doc_id = location["customDocumentLocation"].get("id", "Unknown") + elif "s3Location" in location: + # Extract meaningful part from S3 URI + doc_id = location["s3Location"].get("uri", "") score = result.get("score", 0.0) formatted.append(f"\nScore: {score:.4f}") formatted.append(f"Document ID: {doc_id}") diff --git a/tests/code_interpreter/test_agent_core_code_interpreter.py b/tests/code_interpreter/test_agent_core_code_interpreter.py index b291863b..7db5266b 100644 --- a/tests/code_interpreter/test_agent_core_code_interpreter.py +++ b/tests/code_interpreter/test_agent_core_code_interpreter.py @@ -64,7 +64,7 @@ def test_constructor_custom_identifier_initialization(): mock_resolve.return_value = "us-west-2" custom_id = "custom-interpreter-def456" interpreter = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id) - + assert interpreter.region == "us-west-2" assert interpreter.identifier == custom_id assert interpreter._sessions == {} @@ -76,7 +76,7 @@ def test_constructor_default_identifier_fallback(): with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" interpreter = AgentCoreCodeInterpreter(region="us-west-2") - + assert interpreter.region == "us-west-2" assert interpreter.identifier == "aws.codeinterpreter.v1" assert interpreter._sessions == {} @@ -89,7 +89,7 @@ def test_constructor_backward_compatibility_region_only(): mock_resolve.return_value = "us-east-1" # This is how existing code would call the constructor interpreter = AgentCoreCodeInterpreter("us-east-1") - + assert interpreter.region == "us-east-1" assert interpreter.identifier == "aws.codeinterpreter.v1" assert interpreter._sessions == {} @@ -102,7 +102,7 @@ def test_constructor_backward_compatibility_no_params(): mock_resolve.return_value = "us-east-1" # This is how existing code would call the constructor interpreter = AgentCoreCodeInterpreter() - + assert interpreter.region == "us-east-1" assert interpreter.identifier == "aws.codeinterpreter.v1" assert interpreter._sessions == {} @@ -113,26 +113,26 @@ def test_constructor_instance_variable_storage_scenarios(): """Test that instance variable is set correctly in all scenarios.""" with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" - + # Scenario 1: Custom identifier provided custom_id = "test.codeinterpreter.v1" interpreter1 = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id) - assert hasattr(interpreter1, 'identifier') + assert hasattr(interpreter1, "identifier") assert interpreter1.identifier == custom_id - + # Scenario 2: None identifier provided (explicit None) interpreter2 = AgentCoreCodeInterpreter(region="us-west-2", identifier=None) - assert hasattr(interpreter2, 'identifier') + assert hasattr(interpreter2, "identifier") assert interpreter2.identifier == "aws.codeinterpreter.v1" - + # Scenario 3: Empty string identifier provided interpreter3 = AgentCoreCodeInterpreter(region="us-west-2", identifier="") - assert hasattr(interpreter3, 'identifier') + assert hasattr(interpreter3, "identifier") assert interpreter3.identifier == "aws.codeinterpreter.v1" - + # Scenario 4: No identifier parameter provided interpreter4 = AgentCoreCodeInterpreter(region="us-west-2") - assert hasattr(interpreter4, 'identifier') + assert hasattr(interpreter4, "identifier") assert interpreter4.identifier == "aws.codeinterpreter.v1" @@ -142,7 +142,7 @@ def test_constructor_custom_identifier_with_complex_format(): mock_resolve.return_value = "us-west-2" complex_id = "my-custom-interpreter-abc123-prod" interpreter = AgentCoreCodeInterpreter(region="us-west-2", identifier=complex_id) - + assert interpreter.region == "us-west-2" assert interpreter.identifier == complex_id assert interpreter._sessions == {} @@ -225,24 +225,24 @@ def test_init_session_with_custom_identifier(mock_client_class, mock_client): with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" mock_client_class.return_value = mock_client - + # Create interpreter with custom identifier custom_id = "my-custom-interpreter-abc123" interpreter = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id) - + action = InitSessionAction(type="initSession", description="Test session", session_name="custom-session") - + result = interpreter.init_session(action) - + assert result["status"] == "success" assert result["content"][0]["json"]["sessionName"] == "custom-session" assert result["content"][0]["json"]["description"] == "Test session" assert result["content"][0]["json"]["sessionId"] == "test-session-id-123" - + # Verify client was created and started with custom identifier mock_client_class.assert_called_once_with(region="us-west-2") mock_client.start.assert_called_once_with(identifier=custom_id) - + # Verify session was stored assert "custom-session" in interpreter._sessions session_info = interpreter._sessions["custom-session"] @@ -258,23 +258,23 @@ def test_init_session_with_default_identifier(mock_client_class, mock_client): with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" mock_client_class.return_value = mock_client - + # Create interpreter without custom identifier (should use default) interpreter = AgentCoreCodeInterpreter(region="us-west-2") - + action = InitSessionAction(type="initSession", description="Test session", session_name="default-session") - + result = interpreter.init_session(action) - + assert result["status"] == "success" assert result["content"][0]["json"]["sessionName"] == "default-session" assert result["content"][0]["json"]["description"] == "Test session" assert result["content"][0]["json"]["sessionId"] == "test-session-id-123" - + # Verify client was created and started with default identifier mock_client_class.assert_called_once_with(region="us-west-2") mock_client.start.assert_called_once_with(identifier="aws.codeinterpreter.v1") - + # Verify session was stored assert "default-session" in interpreter._sessions session_info = interpreter._sessions["default-session"] @@ -291,17 +291,17 @@ def test_init_session_logging_includes_identifier(mock_logger, mock_client_class with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" mock_client_class.return_value = mock_client - + # Test with custom identifier custom_id = "test.codeinterpreter.v1" interpreter = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id) - + action = InitSessionAction(type="initSession", description="Test session", session_name="log-test-session") - + result = interpreter.init_session(action) - + assert result["status"] == "success" - + # Verify logging calls include identifier information mock_logger.info.assert_any_call( f"Initializing Bedrock AgentCoresandbox session: Test session with identifier: {custom_id}" @@ -318,16 +318,16 @@ def test_init_session_logging_includes_default_identifier(mock_logger, mock_clie with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" mock_client_class.return_value = mock_client - + # Test with default identifier (none provided) interpreter = AgentCoreCodeInterpreter(region="us-west-2") - + action = InitSessionAction(type="initSession", description="Test session", session_name="log-default-session") - + result = interpreter.init_session(action) - + assert result["status"] == "success" - + # Verify logging calls include default identifier information default_id = "aws.codeinterpreter.v1" mock_logger.info.assert_any_call( @@ -346,18 +346,18 @@ def test_init_session_error_logging_includes_identifier(mock_logger, mock_client mock_resolve.return_value = "us-west-2" mock_client.start.side_effect = Exception("Start failed") mock_client_class.return_value = mock_client - + # Test with custom identifier custom_id = "error.codeinterpreter.v1" interpreter = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id) - + action = InitSessionAction(type="initSession", description="Test session", session_name="error-session") - + result = interpreter.init_session(action) - + assert result["status"] == "error" assert "Failed to initialize session 'error-session': Start failed" in result["content"][0]["text"] - + # Verify error logging includes identifier information mock_logger.error.assert_called_once_with( f"Failed to initialize session 'error-session' with identifier: {custom_id}. Error: Start failed" @@ -370,43 +370,43 @@ def test_init_session_multiple_identifiers_verification(mock_client_class, mock_ with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: mock_resolve.return_value = "us-west-2" mock_client_class.return_value = mock_client - + # Create first interpreter with custom identifier custom_id1 = "first.codeinterpreter.v1" interpreter1 = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id1) - + # Create second interpreter with different custom identifier custom_id2 = "second.codeinterpreter.v1" interpreter2 = AgentCoreCodeInterpreter(region="us-west-2", identifier=custom_id2) - + # Create third interpreter with default identifier interpreter3 = AgentCoreCodeInterpreter(region="us-west-2") - + # Test first interpreter action1 = InitSessionAction(type="initSession", description="First session", session_name="session1") result1 = interpreter1.init_session(action1) assert result1["status"] == "success" - + # Test second interpreter action2 = InitSessionAction(type="initSession", description="Second session", session_name="session2") result2 = interpreter2.init_session(action2) assert result2["status"] == "success" - + # Test third interpreter action3 = InitSessionAction(type="initSession", description="Third session", session_name="session3") result3 = interpreter3.init_session(action3) assert result3["status"] == "success" - + # Verify each interpreter used its correct identifier assert mock_client.start.call_count == 3 call_args_list = mock_client.start.call_args_list - + # First call should use custom_id1 assert call_args_list[0] == ((), {"identifier": custom_id1}) - + # Second call should use custom_id2 assert call_args_list[1] == ((), {"identifier": custom_id2}) - + # Third call should use default identifier assert call_args_list[2] == ((), {"identifier": "aws.codeinterpreter.v1"}) @@ -438,7 +438,7 @@ def test_init_session_client_start_exception(mock_client_class, interpreter, moc action = InitSessionAction(type="initSession", description="Test session", session_name="fail-session") result = interpreter.init_session(action) - + assert result["status"] == "error" assert "Failed to initialize session 'fail-session': Start failed" in result["content"][0]["text"] diff --git a/tests/test_retrieve.py b/tests/test_retrieve.py index 6cdaba08..8273f8ae 100644 --- a/tests/test_retrieve.py +++ b/tests/test_retrieve.py @@ -116,6 +116,22 @@ def test_format_results_for_display(): empty_formatted = retrieve.format_results_for_display([]) assert empty_formatted == "No results found above score threshold." + # Test with s3Location + s3_results = [ + { + "content": {"text": "S3 content", "type": "TEXT"}, + "location": { + "s3Location": {"uri": "s3://bucket/key/document.pdf"}, + "type": "S3", + }, + "score": 0.88, + } + ] + s3_formatted = retrieve.format_results_for_display(s3_results) + assert "Score: 0.8800" in s3_formatted + assert "Document ID: s3://bucket/key/document.pdf" in s3_formatted + assert "Content: S3 content" in s3_formatted + def test_retrieve_tool_direct(mock_boto3_client): """Test direct invocation of the retrieve tool."""