-
-
Notifications
You must be signed in to change notification settings - Fork 532
feat(ai): add Google Gemini LLM provider support #3967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,24 @@ | ||
| """Factory function to get the configured embedder.""" | ||
|
|
||
| from django.conf import settings | ||
|
|
||
| from apps.ai.embeddings.base import Embedder | ||
| from apps.ai.embeddings.google import GoogleEmbedder | ||
| from apps.ai.embeddings.openai import OpenAIEmbedder | ||
|
|
||
|
|
||
| def get_embedder() -> Embedder: | ||
| """Get the configured embedder. | ||
|
|
||
| Currently returns OpenAI embedder, but can be extended to support | ||
| Currently returns OpenAI and Google embedder, but can be extended to support | ||
| other providers (e.g., Anthropic, Cohere, etc.). | ||
|
|
||
| Returns: | ||
| Embedder instance configured for the current provider. | ||
|
|
||
| """ | ||
| # Currently OpenAI, but can be extended to support other providers | ||
| # Currently OpenAI and Google, but can be extended to support other providers | ||
| if settings.LLM_PROVIDER == "google": | ||
| return GoogleEmbedder() | ||
|
|
||
| return OpenAIEmbedder() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| """Google implementation of embedder.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| try: | ||
| from google import genai | ||
| except ImportError: | ||
| # Fallback to deprecated package if new one not available | ||
| try: | ||
| import warnings | ||
|
|
||
| import google.generativeai as genai | ||
|
|
||
| warnings.warn( | ||
| ( | ||
| "google.generativeai is deprecated. " | ||
| "Please install google-genai package: pip install google-genai" | ||
| ), | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| except ImportError: | ||
| genai = None | ||
|
|
||
| import requests | ||
| from django.conf import settings | ||
|
|
||
| from apps.ai.embeddings.base import Embedder | ||
|
|
||
|
|
||
| class GoogleEmbedder(Embedder): | ||
| """Google implementation of embedder using Google Generative AI SDK.""" | ||
|
|
||
| def __init__(self, model: str = "gemini-embedding-001") -> None: | ||
| """Initialize Google embedder. | ||
|
|
||
| Args: | ||
| model: The Google embedding model to use. | ||
| Default: gemini-embedding-001 (recommended, 768 dimensions) | ||
| Note: text-embedding-004 is deprecated | ||
|
|
||
| """ | ||
| self.api_key = settings.GOOGLE_API_KEY | ||
| self.model = model | ||
| # gemini-embedding-001 has 768 dimensions | ||
| self._dimensions = 768 | ||
|
|
||
| # Use Google Generative AI SDK (preferred method) | ||
| # The SDK handles endpoint URLs and authentication automatically | ||
| if genai: | ||
| genai.configure(api_key=self.api_key) | ||
| self.use_sdk = True | ||
| else: | ||
| # Fallback to REST API (not recommended - use SDK instead) | ||
| self.base_url = "https://generativelanguage.googleapis.com/v1beta" | ||
| self.use_sdk = False | ||
| import warnings | ||
|
|
||
| warnings.warn( | ||
| "Google GenAI SDK not available. Install it with: pip install google-genai", | ||
| UserWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| def embed_query(self, text: str) -> list[float]: | ||
| """Generate embedding for a query string. | ||
|
|
||
| Args: | ||
| text: The query text to embed. | ||
|
|
||
| Returns: | ||
| List of floats representing the embedding vector. | ||
|
|
||
| """ | ||
| if self.use_sdk and genai: | ||
| # Use Google Generative AI SDK (preferred method) | ||
| # SDK automatically handles the correct endpoint and model format | ||
| result = genai.embed_content( | ||
| model=self.model, | ||
| content=text, | ||
| ) | ||
| # SDK returns embedding in 'embedding' key | ||
| return result["embedding"] | ||
|
|
||
| # Fallback to REST API | ||
| endpoint = f"{self.base_url}/models/{self.model}:embedContent?key={self.api_key}" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. P1: Security: API key is passed as a URL query parameter, which can leak into server/proxy logs and error reports. Use the `x-goog-api-key` request header instead of the `?key=` query parameter. Prompt for AI agents |
||
| response = requests.post( | ||
| endpoint, | ||
| headers={"Content-Type": "application/json"}, | ||
| json={ | ||
| "content": {"parts": [{"text": text}]}, | ||
| }, | ||
| timeout=30, | ||
| ) | ||
| response.raise_for_status() | ||
| data = response.json() | ||
| return data["embedding"]["values"] | ||
|
|
||
| def embed_documents(self, texts: list[str]) -> list[list[float]]: | ||
| """Generate embeddings for multiple documents. | ||
|
|
||
| Args: | ||
| texts: List of document texts to embed. | ||
|
|
||
| Returns: | ||
| List of embedding vectors, one per document. | ||
|
|
||
| """ | ||
| if self.use_sdk and genai: | ||
| # Use Google Generative AI SDK (preferred method) | ||
| # SDK handles batching automatically | ||
| results = [] | ||
| for text in texts: | ||
| result = genai.embed_content( | ||
| model=self.model, | ||
| content=text, | ||
| ) | ||
| results.append(result["embedding"]) | ||
| return results | ||
|
|
||
| # Fallback to REST API | ||
| endpoint = f"{self.base_url}/models/{self.model}:batchEmbedContents?key={self.api_key}" | ||
| response = requests.post( | ||
| endpoint, | ||
| headers={"Content-Type": "application/json"}, | ||
| json={ | ||
| "requests": [{"content": {"parts": [{"text": text}]}} for text in texts], | ||
| }, | ||
| timeout=60, | ||
| ) | ||
| response.raise_for_status() | ||
| data = response.json() | ||
| return [item["embedding"]["values"] for item in data["embeddings"]] | ||
|
|
||
| def get_dimensions(self) -> int: | ||
| """Get the dimension of embeddings produced by this embedder. | ||
|
|
||
| Returns: | ||
| Integer representing the embedding dimension. | ||
|
|
||
| """ | ||
| return self._dimensions | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,15 +3,16 @@ | |
| import os | ||
| from unittest.mock import Mock, patch | ||
|
|
||
| import pytest | ||
|
|
||
| from apps.ai.common.llm_config import get_llm | ||
|
|
||
|
|
||
| class TestLLMConfig: | ||
| """Test cases for LLM configuration.""" | ||
|
|
||
| @patch.dict(os.environ, {"LLM_PROVIDER": "openai", "DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) | ||
| @patch.dict( | ||
| os.environ, | ||
| {"DJANGO_LLM_PROVIDER": "openai", "DJANGO_OPEN_AI_SECRET_KEY": "test-key"}, | ||
| ) | ||
| @patch("apps.ai.common.llm_config.LLM") | ||
| def test_get_llm_openai_default(self, mock_llm): | ||
| """Test getting OpenAI LLM with default model.""" | ||
|
|
@@ -21,7 +22,7 @@ def test_get_llm_openai_default(self, mock_llm): | |
| result = get_llm() | ||
|
|
||
| mock_llm.assert_called_once_with( | ||
| model="gpt-4.1-mini", | ||
| model="gpt-4o-mini", | ||
| api_key="test-key", | ||
| temperature=0.1, | ||
| ) | ||
|
|
@@ -30,9 +31,9 @@ def test_get_llm_openai_default(self, mock_llm): | |
| @patch.dict( | ||
| os.environ, | ||
| { | ||
| "LLM_PROVIDER": "openai", | ||
| "DJANGO_LLM_PROVIDER": "openai", | ||
| "DJANGO_OPEN_AI_SECRET_KEY": "test-key", | ||
| "OPENAI_MODEL_NAME": "gpt-4", | ||
| "DJANGO_OPENAI_MODEL_NAME": "gpt-4", | ||
| }, | ||
| ) | ||
| @patch("apps.ai.common.llm_config.LLM") | ||
|
|
@@ -53,50 +54,73 @@ def test_get_llm_openai_custom_model(self, mock_llm): | |
| @patch.dict( | ||
| os.environ, | ||
| { | ||
| "LLM_PROVIDER": "anthropic", | ||
| "ANTHROPIC_API_KEY": "test-anthropic-key", | ||
| "DJANGO_LLM_PROVIDER": "unsupported", | ||
| "DJANGO_OPEN_AI_SECRET_KEY": "test-key", | ||
| }, | ||
| ) | ||
| @patch("apps.ai.common.llm_config.logger") | ||
| @patch("apps.ai.common.llm_config.LLM") | ||
| def test_get_llm_anthropic_default(self, mock_llm): | ||
| """Test getting Anthropic LLM with default model.""" | ||
| def test_get_llm_unsupported_provider(self, mock_llm, mock_logger): | ||
| """Test getting LLM with unsupported provider logs warning and falls back to OpenAI.""" | ||
| mock_llm_instance = Mock() | ||
| mock_llm.return_value = mock_llm_instance | ||
|
|
||
| result = get_llm() | ||
|
|
||
| # Should log warning about unrecognized provider | ||
| mock_logger.warning.assert_called_once() | ||
| # Should fallback to OpenAI | ||
| mock_llm.assert_called_once_with( | ||
| model="claude-3-5-sonnet-20241022", | ||
| api_key="test-anthropic-key", | ||
| model="gpt-4o-mini", | ||
| api_key="test-key", | ||
| temperature=0.1, | ||
| ) | ||
| assert result == mock_llm_instance | ||
|
|
||
| @patch.dict( | ||
| os.environ, | ||
| { | ||
| "LLM_PROVIDER": "anthropic", | ||
| "ANTHROPIC_API_KEY": "test-anthropic-key", | ||
| "ANTHROPIC_MODEL_NAME": "claude-3-opus", | ||
| "DJANGO_LLM_PROVIDER": "google", | ||
| "DJANGO_GOOGLE_API_KEY": "test-google-key", | ||
| "DJANGO_GOOGLE_MODEL_NAME": "gemini-2.0-flash", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. P2: This test claims to verify the default Google model, but it explicitly provides `DJANGO_GOOGLE_MODEL_NAME="gemini-2.0-flash"`, so the default-model fallback path is never exercised. Remove the env var to test the actual default, or rename the test to reflect that it checks a configured model. Prompt for AI agents |
||
| }, | ||
| ) | ||
| @patch("apps.ai.common.llm_config.LLM") | ||
| def test_get_llm_anthropic_custom_model(self, mock_llm): | ||
| """Test getting Anthropic LLM with custom model.""" | ||
| def test_get_llm_google(self, mock_llm): | ||
| """Test getting Google LLM with default model.""" | ||
| mock_llm_instance = Mock() | ||
| mock_llm.return_value = mock_llm_instance | ||
|
|
||
| result = get_llm() | ||
|
|
||
| mock_llm.assert_called_once_with( | ||
| model="claude-3-opus", | ||
| api_key="test-anthropic-key", | ||
| model="gemini-2.0-flash", | ||
| base_url="https://generativelanguage.googleapis.com/v1beta/openai/", | ||
| api_key="test-google-key", | ||
| temperature=0.1, | ||
| ) | ||
| assert result == mock_llm_instance | ||
|
|
||
| @patch.dict(os.environ, {"LLM_PROVIDER": "unsupported"}) | ||
| def test_get_llm_unsupported_provider(self): | ||
| """Test getting LLM with unsupported provider raises error.""" | ||
| with pytest.raises(ValueError, match="Unsupported LLM provider: unsupported"): | ||
| get_llm() | ||
| @patch.dict( | ||
| os.environ, | ||
| { | ||
| "DJANGO_LLM_PROVIDER": "google", | ||
| "DJANGO_GOOGLE_API_KEY": "test-google-key", | ||
| "DJANGO_GOOGLE_MODEL_NAME": "gemini-pro", | ||
| }, | ||
| ) | ||
| @patch("apps.ai.common.llm_config.LLM") | ||
| def test_get_llm_google_custom_model(self, mock_llm): | ||
| """Test getting Google LLM with custom model.""" | ||
| mock_llm_instance = Mock() | ||
| mock_llm.return_value = mock_llm_instance | ||
|
|
||
| result = get_llm() | ||
|
|
||
| mock_llm.assert_called_once_with( | ||
| model="gemini-pro", | ||
| base_url="https://generativelanguage.googleapis.com/v1beta/openai/", | ||
| api_key="test-google-key", | ||
| temperature=0.1, | ||
| ) | ||
| assert result == mock_llm_instance | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P0: Bug: The new `google.genai` SDK doesn't have `genai.configure()` or `genai.embed_content()` — these are from the deprecated `google.generativeai` package. The new SDK uses a client-based API: `client = genai.Client(api_key=...)` and `client.models.embed_content(...)`. This code will raise `AttributeError` at runtime.
Initialize a client in `__init__` and use `self.client.models.embed_content(...)` in the embed methods, consistent with how `OpenAIEmbedder` uses `self.client`.
Prompt for AI agents