diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml
index bdd9ee4dc0..01e6d7a91e 100644
--- a/lib/crewai/pyproject.toml
+++ b/lib/crewai/pyproject.toml
@@ -85,6 +85,9 @@ voyageai = [
litellm = [
    "litellm>=1.74.9",
]
+boto3 = [
+    "boto3>=1.40.45",
+]

[project.scripts]
diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index c3a759bae1..b5fc653cbc 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -367,6 +367,14 @@ def _get_native_provider(cls, provider: str) -> type | None:
            except ImportError:
                return None

+        elif provider == "bedrock":
+            try:
+                from crewai.llms.providers.bedrock.completion import BedrockCompletion
+
+                return BedrockCompletion
+            except ImportError:
+                return None
+
        return None

    def __init__(
diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/__init__.py b/lib/crewai/src/crewai/llms/providers/bedrock/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
new file mode 100644
index 0000000000..615d8838d1
--- /dev/null
+++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py
@@ -0,0 +1,553 @@
+from collections.abc import Mapping, Sequence
import logging
import os
from typing import Any

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)


try:
    from boto3.session import Session
    from botocore.config import Config
    from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
    raise ImportError(
        "AWS Bedrock native provider requires boto3. Install it with: `uv add boto3`"
    ) from None


class BedrockCompletion(BaseLLM):
    """AWS Bedrock native completion implementation using the Converse API.

    This class provides direct integration with AWS Bedrock using the modern
    Converse API, which offers a unified interface across all Bedrock models.
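
    Example (a minimal sketch; assumes valid AWS credentials are available in
    the environment and that the account has access to the referenced model in
    the target region):

        llm = BedrockCompletion(
            model="anthropic.claude-3-5-sonnet-20241022-v2:0",
            region_name="us-east-1",
            temperature=0.3,
            max_tokens=1024,
        )
        answer = llm.call("Summarize the AWS Converse API in one sentence.")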
    """

    def __init__(
        self,
        model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
        region_name: str = "us-east-1",
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        stop_sequences: Sequence[str] | None = None,
        stream: bool = False,
        **kwargs,
    ):
        """Initialize AWS Bedrock completion client."""
        # Extract provider from kwargs to avoid duplicate argument
        kwargs.pop("provider", None)

        super().__init__(
            model=model,
            temperature=temperature,
            stop=stop_sequences or [],
            provider="bedrock",
            **kwargs,
        )

        # Initialize Bedrock client with proper configuration
        session = Session(
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            region_name=region_name,
        )

        # Configure client with timeouts and retries following AWS best practices
        config = Config(
            connect_timeout=60,
            read_timeout=300,
            retries={
                "max_attempts": 3,
                "mode": "adaptive",
            },
            tcp_keepalive=True,
        )

        self.client = session.client("bedrock-runtime", config=config)
        self.region_name = region_name

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_model = "claude" in model.lower()
        self.supports_tools = True  # Converse API supports tools for most models
        self.supports_streaming = True

        # Use the model ID as given; inference profile IDs pass through unchanged
        self.model_id = model

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: Sequence[Mapping[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        """Call AWS Bedrock Converse API."""
        try:
            # Emit call started event
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            # Format messages for Converse API
            formatted_messages, system_message = self._format_messages_for_converse(
                messages
            )

            # Prepare tool configuration
            tool_config = None
            if tools:
                tool_config = {"tools": self._format_tools_for_converse(tools)}

            # Prepare request body
            body = {
                "inferenceConfig": self._get_inference_config(),
            }

            # Add system message if present
            if system_message:
                body["system"] = [{"text": system_message}]

            # Add tool config if present
            if tool_config:
                body["toolConfig"] = tool_config

            if self.stream:
                return self._handle_streaming_converse(
                    formatted_messages, body, available_functions, from_task, from_agent
                )

            return self._handle_converse(
                formatted_messages, body, available_functions, from_task, from_agent
            )

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"AWS Bedrock API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle non-streaming converse API call following AWS best practices."""
        try:
            # Validate messages format before API call
            if not messages:
                raise ValueError("Messages cannot be empty")

            # Ensure we have valid message structure
            for i, msg in enumerate(messages):
                if (
                    not isinstance(msg, dict)
                    or "role" not in msg
                    or "content" not in msg
                ):
                    raise ValueError(f"Invalid message format at index {i}")

            # Call Bedrock Converse API with proper error handling
            response = self.client.converse(
                modelId=self.model_id, messages=messages, **body
            )

            # Track token usage according to AWS response format
            if "usage" in response:
                self._track_token_usage_internal(response["usage"])

            # Extract content following AWS response structure
            output = response.get("output", {})
            message = output.get("message", {})
            content = message.get("content", [])

            if not content:
                logging.warning("No content in Bedrock response")
                return (
                    "I apologize, but I received an empty response. Please try again."
                )

            # Extract text content from response
            text_content = ""
            for content_block in content:
                # Handle different content block types as per AWS documentation
                if "text" in content_block:
                    text_content += content_block["text"]
                elif "toolUse" in content_block and available_functions:
                    # Tool use blocks are keyed by "toolUse" in the Converse API
                    tool_use = content_block["toolUse"]
                    function_name = tool_use.get("name")
                    function_args = tool_use.get("input", {})

                    result = self._handle_tool_execution(
                        function_name=function_name,
                        function_args=function_args,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                    )

                    if result is not None:
                        return result

            # Apply stop sequences if configured
            text_content = self._apply_stop_words(text_content)

            # Validate final response
            if not text_content or text_content.strip() == "":
                logging.warning("Extracted empty text content from Bedrock response")
                text_content = "I apologize, but I couldn't generate a proper response. Please try again."

            self._emit_call_completed_event(
                response=text_content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=messages,
            )

            return text_content

        except ClientError as e:
            # Handle all AWS ClientError exceptions as per documentation
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_msg = e.response.get("Error", {}).get("Message", str(e))

            # Log the specific error for debugging
            logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")

            # Handle specific error codes as documented
            if error_code == "ValidationException":
                # This is the error we're seeing with Cohere
                if "last turn" in error_msg and "user message" in error_msg:
                    raise ValueError(
                        f"Conversation format error: {error_msg}. Check message alternation."
                    ) from e
                raise ValueError(f"Request validation failed: {error_msg}") from e
            if error_code == "AccessDeniedException":
                raise PermissionError(
                    f"Access denied to model {self.model_id}: {error_msg}"
                ) from e
            if error_code == "ResourceNotFoundException":
                raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
            if error_code == "ThrottlingException":
                raise RuntimeError(
                    f"API throttled, please retry later: {error_msg}"
                ) from e
            if error_code == "ModelTimeoutException":
                raise TimeoutError(f"Model request timed out: {error_msg}") from e
            if error_code == "ServiceQuotaExceededException":
                raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
            if error_code == "ModelNotReadyException":
                raise RuntimeError(
                    f"Model {self.model_id} not ready: {error_msg}"
                ) from e
            if error_code == "ModelErrorException":
                raise RuntimeError(f"Model error: {error_msg}") from e
            if error_code == "InternalServerException":
                raise RuntimeError(f"Internal server error: {error_msg}") from e
            if error_code == "ServiceUnavailableException":
                raise RuntimeError(f"Service unavailable: {error_msg}") from e

            raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e

        except BotoCoreError as e:
            error_msg = f"Bedrock connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e
        except Exception as e:
            # Catch any other unexpected errors
            error_msg = f"Unexpected error in Bedrock converse call: {e}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _handle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle streaming converse API call."""
        full_response = ""

        try:
            response = self.client.converse_stream(
                modelId=self.model_id, messages=messages, **body
            )

            stream = response.get("stream")
            if stream:
                for event in stream:
                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
                            logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
                            full_response += text_chunk
                            self._emit_stream_chunk_event(
                                chunk=text_chunk,
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                    elif "messageStop" in event:
                        # Handle end of message
                        break

        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
        except BotoCoreError as e:
            error_msg = f"Bedrock streaming connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e

        # Apply stop words to full response
        full_response = self._apply_stop_words(full_response)

        # Ensure we don't return empty content
        if not full_response or full_response.strip() == "":
            logging.warning("Bedrock streaming returned empty content, using fallback")
            full_response = (
                "I apologize, but I couldn't generate a response. Please try again."
            )

        # Emit completion event
        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages,
        )

        return full_response

    def _format_messages_for_converse(
        self, messages: str | list[dict[str, str]]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Format messages for Converse API following AWS documentation."""
        # Use base class formatting first
        formatted_messages = self._format_messages(messages)

        converse_messages = []
        system_message = None

        for message in formatted_messages:
            role = message.get("role")
            content = message.get("content", "")

            if role == "system":
                # Extract system message - Converse API handles it separately
                if system_message:
                    system_message += f"\n\n{content}"
                else:
                    system_message = content
            else:
                # Convert to Converse API format with proper content structure
                converse_messages.append({"role": role, "content": [{"text": content}]})

        # CRITICAL: Handle model-specific conversation requirements
        # Cohere and some other models require conversation to end with user message
        if converse_messages:
            last_message = converse_messages[-1]
            if last_message["role"] == "assistant":
                # For Cohere models, add a continuation user message
                if "cohere" in self.model.lower():
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "text": "Please continue and provide your final answer."
                                }
                            ],
                        }
                    )
                # For other models that might have similar requirements
                elif any(
                    model_family in self.model.lower()
                    for model_family in ["command", "coral"]
                ):
                    converse_messages.append(
                        {
                            "role": "user",
                            "content": [{"text": "Continue your response."}],
                        }
                    )

        # Ensure first message is from user (required by Converse API)
        if not converse_messages:
            converse_messages.append(
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                }
            )
        elif converse_messages[0]["role"] != "user":
            converse_messages.insert(
                0,
                {
                    "role": "user",
                    "content": [{"text": "Hello, please help me with my request."}],
                },
            )

        return converse_messages, system_message

    def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
        """Convert CrewAI tools to Converse API format following AWS specification."""
        from crewai.llms.providers.utils.common import safe_tool_conversion

        converse_tools = []

        for tool in tools:
            try:
                name, description, parameters = safe_tool_conversion(tool, "Bedrock")

                converse_tool = {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                    }
                }

                if parameters and isinstance(parameters, dict):
                    converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}

                converse_tools.append(converse_tool)

            except Exception as e:  # noqa: PERF203
                logging.warning(
                    f"Failed to convert tool {tool.get('name', 'unknown')}: {e}"
                )
                continue

        return converse_tools

    def _get_inference_config(self) -> dict[str, Any]:
        """Get inference configuration following AWS Converse API specification."""
        config = {}

        if self.max_tokens:
            config["maxTokens"] = self.max_tokens

        if self.temperature is not None:
            config["temperature"] = float(self.temperature)
        if self.top_p is not None:
            config["topP"] = float(self.top_p)
        if self.stop_sequences:
            config["stopSequences"] = self.stop_sequences

        if self.is_claude_model and self.top_k is not None:
            # top_k is supported by Claude models
            config["topK"] = int(self.top_k)

        return config
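
    # Illustrative sketch (hypothetical values): with temperature=0.3,
    # max_tokens=1024, and stop_sequences=["Observation:"], the method above
    # returns {"maxTokens": 1024, "temperature": 0.3, "stopSequences": ["Observation:"]},
    # which call() passes to the Converse API as the `inferenceConfig` field.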

    def _handle_client_error(self, e: ClientError) -> str:
        """Handle AWS ClientError with specific error codes and return error message."""
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        error_msg = e.response.get("Error", {}).get("Message", str(e))

        error_mapping = {
            "AccessDeniedException": f"Access denied to model {self.model_id}: {error_msg}",
            "ResourceNotFoundException": f"Model {self.model_id} not found: {error_msg}",
            "ThrottlingException": f"API throttled, please retry later: {error_msg}",
            "ValidationException": f"Invalid request: {error_msg}",
            "ModelTimeoutException": f"Model request timed out: {error_msg}",
            "ServiceQuotaExceededException": f"Service quota exceeded: {error_msg}",
            "ModelNotReadyException": f"Model {self.model_id} not ready: {error_msg}",
            "ModelErrorException": f"Model error: {error_msg}",
        }

        full_error_msg = error_mapping.get(
            error_code, f"Bedrock API error: {error_msg}"
        )
        logging.error(f"Bedrock client error ({error_code}): {full_error_msg}")

        return full_error_msg

    def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
        """Track token usage from Bedrock response."""
        input_tokens = usage.get("inputTokens", 0)
        output_tokens = usage.get("outputTokens", 0)
        total_tokens = usage.get("totalTokens", input_tokens + output_tokens)

        self._token_usage["prompt_tokens"] += input_tokens
        self._token_usage["completion_tokens"] += output_tokens
        self._token_usage["total_tokens"] += total_tokens
        self._token_usage["successful_requests"] += 1

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO

        # Context window sizes for common Bedrock models
        context_windows = {
            "anthropic.claude-3-5-sonnet": 200000,
            "anthropic.claude-3-5-haiku": 200000,
            "anthropic.claude-3-opus": 200000,
            "anthropic.claude-3-sonnet": 200000,
            "anthropic.claude-3-haiku": 200000,
            "anthropic.claude-3-7-sonnet": 200000,
            "anthropic.claude-v2": 100000,
            "amazon.titan-text-express": 8000,
            "ai21.j2-ultra": 8192,
            "cohere.command-text": 4096,
            "meta.llama2-13b-chat": 4096,
            "meta.llama2-70b-chat": 4096,
            "meta.llama3-70b-instruct": 128000,
            "deepseek.r1": 32768,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size
        return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
diff --git a/uv.lock b/uv.lock
index 78ff815155..bfcd012842 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1020,6 +1020,9 @@ aisuite = [
aws = [
    { name = "boto3" },
]
+boto3 = [
+    { name = "boto3" },
+]
docling = [
    { name = "docling" },
]
@@ -1060,6 +1063,7 @@ requires-dist = [
    { name = "appdirs", specifier = ">=1.4.4" },
    { name = "blinker", specifier = ">=1.9.0" },
    { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
+    { name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
    { name = "chromadb", specifier = "~=1.1.0" },
    { name = "click", specifier = ">=8.1.7" },
    { name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
@@ -1095,7 +1099,7 @@ requires-dist = [
    { name = "uv", specifier = ">=0.4.25" },
    { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
]
-provides-extras = ["aisuite", "aws", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
+provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]

[[package]]
name = "crewai-devtools"