Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/crewai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ litellm = [
boto3 = [
"boto3>=1.40.45",
]
google-genai = [
"google-genai>=1.2.0",
]


[project.scripts]
Expand Down
107 changes: 86 additions & 21 deletions lib/crewai/src/crewai/llms/providers/gemini/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


try:
from google import genai # type: ignore
from google.genai import types # type: ignore
from google.genai.errors import APIError # type: ignore
from google import genai
from google.genai import types
from google.genai.errors import APIError
except ImportError:
raise ImportError(
"Google Gen AI native provider not available, to install: `uv add google-genai`"
Expand All @@ -40,6 +40,7 @@ def __init__(
stop_sequences: list[str] | None = None,
stream: bool = False,
safety_settings: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
**kwargs,
):
"""Initialize Google Gemini chat completion client.
Expand All @@ -56,35 +57,27 @@ def __init__(
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
Supports parameters like http_options, credentials, debug_config, etc.
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)

# Get API configuration
# Store client params for later use
self.client_params = client_params or {}

# Get API configuration with environment variable fallbacks
self.api_key = (
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
)
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"

# Initialize client based on available configuration
if self.project:
# Use Vertex AI
self.client = genai.Client(
vertexai=True,
project=self.project,
location=self.location,
)
elif self.api_key:
# Use Gemini Developer API
self.client = genai.Client(api_key=self.api_key)
else:
raise ValueError(
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
)
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

self.client = self._initialize_client(use_vertexai)

# Store completion parameters
self.top_p = top_p
Expand All @@ -99,6 +92,78 @@ def __init__(
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2

def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
"""Initialize the Google Gen AI client with proper parameter handling.

Args:
use_vertexai: Whether to use Vertex AI (from environment variable)

Returns:
Initialized Google Gen AI Client
"""
client_params = {}

if self.client_params:
client_params.update(self.client_params)

if use_vertexai or self.project:
client_params.update(
{
"vertexai": True,
"project": self.project,
"location": self.location,
}
)

client_params.pop("api_key", None)

elif self.api_key:
client_params["api_key"] = self.api_key

client_params.pop("vertexai", None)
client_params.pop("project", None)
client_params.pop("location", None)

else:
try:
return genai.Client(**client_params)
except Exception as e:
raise ValueError(
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
) from e

return genai.Client(**client_params)

def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters for compatibility with base class.

Note: This method is kept for compatibility but the Google Gen AI SDK
uses a different initialization pattern via the Client constructor.
"""
params = {}

if (
hasattr(self, "client")
and hasattr(self.client, "vertexai")
and self.client.vertexai
):
# Vertex AI configuration
params.update(
{
"vertexai": True,
"project": self.project,
"location": self.location,
}
)
elif self.api_key:
params["api_key"] = self.api_key

if self.client_params:
params.update(self.client_params)

return params

def call(
self,
messages: str | list[dict[str, str]],
Expand Down Expand Up @@ -427,7 +492,7 @@ def supports_function_calling(self) -> bool:

def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return self._supports_stop_words_implementation()
return True

def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
Expand Down
Loading