Skip to content

Commit

Permalink
Add Gemini support with Google Default Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
luxzoli committed May 28, 2024
1 parent 69c9765 commit 01aa5e1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
35 changes: 28 additions & 7 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def get_usage(response) -> Dict:
}

def create(self, params: Dict) -> ChatCompletion:
if self.api_key is None:
self.use_vertexai = True
else:
self.use_vertexai = False
model_name = params.get("model", "gemini-pro")
if not model_name:
raise ValueError(
Expand Down Expand Up @@ -138,7 +142,9 @@ def create(self, params: Dict) -> ChatCompletion:
# A. create and call the chat model.
gemini_messages = self.oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(model_name)
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
# we use chat model by default
model = genai.GenerativeModel(
Expand Down Expand Up @@ -173,7 +179,9 @@ def create(self, params: Dict) -> ChatCompletion:
elif model_name == "gemini-pro-vision":
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(model_name)
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
Expand All @@ -192,7 +200,10 @@ def create(self, params: Dict) -> ChatCompletion:

response = model.generate_content(user_message, stream=stream)
# ans = response.text
ans: str = response._result.candidates[0].content.parts[0].text
if self.use_vertexai:
ans: str = response.candidates[0].content.parts[0].text
else:
ans: str = response._result.candidates[0].content.parts[0].text

prompt_tokens = model.count_tokens(user_message).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
Expand Down Expand Up @@ -233,11 +244,21 @@ def oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
if isinstance(msg, dict):
assert "type" in msg, f"Missing 'type' field in message: {msg}"
if msg["type"] == "text":
rst.append(Part(text=msg["text"]))
if self.use_vertexai:
rst.append(VertexAIPart.from_text(text=msg["text"]))
else:
rst.append(Part(text=msg["text"]))
elif msg["type"] == "image_url":
b64_img = get_image_data(msg["image_url"]["url"])
img = _to_pil(b64_img)
rst.append(img)
if self.use_vertexai:
img = get_image_data(msg["image_url"]["url"], use_b64=False)
# img = _to_pil(b64_img)
# img_part = VertexAIPart.from_image(img)
img_part = VertexAIPart.from_data(img, mime_type="image/png")
rst.append(img_part)
else:
b64_img = get_image_data(msg["image_url"]["url"])
img = _to_pil(b64_img)
rst.append(img)
else:
raise ValueError(f"Unsupported message type: {msg['type']}")
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graph": ["networkx", "matplotlib"],
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
"gemini": ["google-generativeai>=0.5,<1", "google-cloud-aiplatform", "google-auth", "pillow", "pydantic"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
"cosmosdb": ["azure-cosmos>=4.2.0"],
Expand Down
8 changes: 3 additions & 5 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ def gemini_client():
return GeminiClient(api_key="fake_api_key")


# Test initialization and configuration
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_initialization():
with pytest.raises(AssertionError):
GeminiClient() # Should raise an AssertionError due to missing API key
@pytest.fixture
def gemini_google_auth_default_client():
return GeminiClient()


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
Expand Down

0 comments on commit 01aa5e1

Please sign in to comment.