Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Gemini without API key #2805

Merged
merged 14 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 166 additions & 100 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@

import google.generativeai as genai
import requests
import vertexai
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import Part as VertexAIPart


class GeminiClient:
Expand All @@ -68,14 +72,33 @@ class GeminiClient:
"max_output_tokens": "max_output_tokens",
}

def initialize_vartexai(self, **params):
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
if "google_application_credentials" in params:
# Path to JSON Keyfile
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
vertexai_init_args = {}
if "project_id" in params:
vertexai_init_args["project"] = params["project_id"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)

def __init__(self, **kwargs):
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GOOGLE_API_KEY")

assert (
self.api_key
), "Please provide api_key in your config list entry for Gemini or set the GOOGLE_API_KEY env variable."
if self.api_key is None:
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
self.use_vertexai = True
self.initialize_vartexai(**kwargs)
else:
self.use_vertexai = False
else:
self.use_vertexai = False
if not self.use_vertexai:
assert ("project_id" not in kwargs) and (
"location" not in kwargs
), "Google Cloud project and compute location cannot be set when using an API Key!"

def message_retrieval(self, response) -> List:
"""
Expand All @@ -102,6 +125,12 @@ def get_usage(response) -> Dict:
}

def create(self, params: Dict) -> ChatCompletion:
if self.use_vertexai:
self.initialize_vartexai(**params)
else:
assert ("project_id" not in params) and (
"location" not in params
), "Google Cloud project and compute location cannot be set when using an API Key!"
model_name = params.get("model", "gemini-pro")
if not model_name:
raise ValueError(
Expand Down Expand Up @@ -133,13 +162,17 @@ def create(self, params: Dict) -> ChatCompletion:

if "vision" not in model_name:
# A. create and call the chat model.
gemini_messages = oai_messages_to_gemini_messages(messages)

# we use chat model by default
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
gemini_messages = self.oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
# we use chat model by default
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
Expand Down Expand Up @@ -167,14 +200,19 @@ def create(self, params: Dict) -> ChatCompletion:
completion_tokens = model.count_tokens(ans).total_tokens
elif model_name == "gemini-pro-vision":
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
user_message = oai_content_to_gemini_content(messages[-1]["content"])
user_message = self.oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
"Warning: Gemini's vision model does not support chat history yet.",
Expand All @@ -184,7 +222,10 @@ def create(self, params: Dict) -> ChatCompletion:

response = model.generate_content(user_message, stream=stream)
# ans = response.text
ans: str = response._result.candidates[0].content.parts[0].text
if self.use_vertexai:
ans: str = response.candidates[0].content.parts[0].text
else:
ans: str = response._result.candidates[0].content.parts[0].text

prompt_tokens = model.count_tokens(user_message).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
Expand All @@ -209,99 +250,111 @@ def create(self, params: Dict) -> ChatCompletion:

return response_oai

def oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
else:
rst.append(Part(text=content))
return rst

assert isinstance(content, list)

for msg in content:
if isinstance(msg, dict):
assert "type" in msg, f"Missing 'type' field in message: {msg}"
if msg["type"] == "text":
if self.use_vertexai:
rst.append(VertexAIPart.from_text(text=msg["text"]))
else:
rst.append(Part(text=msg["text"]))
elif msg["type"] == "image_url":
if self.use_vertexai:
img_url = msg["image_url"]["url"]
re.match(r"data:image/(?:png|jpeg);base64,", img_url)
img = get_image_data(img_url, use_b64=False)
# image/png works with jpeg as well
img_part = VertexAIPart.from_data(img, mime_type="image/png")
rst.append(img_part)
else:
b64_img = get_image_data(msg["image_url"]["url"])
img = _to_pil(b64_img)
rst.append(img)
else:
raise ValueError(f"Unsupported message type: {msg['type']}")
else:
raise ValueError(f"Unsupported message type: {type(msg)}")
return rst

def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
if "1.5" in model_name or "gemini-experimental" in model_name:
# "gemini-1.5-pro-preview-0409"
# Cost is $7 per million input tokens and $21 per million output tokens
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6

if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)

# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6

def concat_parts(self, parts: List[Part]) -> List:
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
"""Concatenate parts with the same type.
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
"""
if not parts:
return []

def oai_content_to_gemini_content(content: Union[str, List]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
rst.append(Part(text=content))
return rst
concatenated_parts = []
previous_part = parts[0]

assert isinstance(content, list)

for msg in content:
if isinstance(msg, dict):
assert "type" in msg, f"Missing 'type' field in message: {msg}"
if msg["type"] == "text":
rst.append(Part(text=msg["text"]))
elif msg["type"] == "image_url":
b64_img = get_image_data(msg["image_url"]["url"])
img = _to_pil(b64_img)
rst.append(img)
for current_part in parts[1:]:
if previous_part.text != "":
if self.use_vertexai:
previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
else:
previous_part.text += current_part.text
else:
raise ValueError(f"Unsupported message type: {msg['type']}")
else:
raise ValueError(f"Unsupported message type: {type(msg)}")
return rst
concatenated_parts.append(previous_part)
previous_part = current_part

if previous_part.text == "":
if self.use_vertexai:
previous_part = VertexAIPart.from_text("empty")
else:
previous_part.text = "empty" # Empty content is not allowed.
concatenated_parts.append(previous_part)

def concat_parts(parts: List[Part]) -> List:
"""Concatenate parts with the same type.
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
"""
if not parts:
return []
return concatenated_parts

concatenated_parts = []
previous_part = parts[0]
def oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
luxzoli marked this conversation as resolved.
Show resolved Hide resolved
"""Convert messages from OAI format to Gemini format.
Make sure the "user" role and "model" role are interleaved.
Also, make sure the last item is from the "user" role.
"""
prev_role = None
rst = []
curr_parts = []
for i, message in enumerate(messages):
parts = self.oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"

if prev_role is None or role == prev_role:
curr_parts += parts
elif role != prev_role:
if self.use_vertexai:
rst.append(VertexAIContent(parts=self.concat_parts(curr_parts), role=prev_role))
else:
rst.append(Content(parts=curr_parts, role=prev_role))
prev_role = role

for current_part in parts[1:]:
if previous_part.text != "":
previous_part.text += current_part.text
# handle the last message
if self.use_vertexai:
rst.append(VertexAIContent(parts=self.concat_parts(curr_parts), role=role))
else:
concatenated_parts.append(previous_part)
previous_part = current_part

if previous_part.text == "":
previous_part.text = "empty" # Empty content is not allowed.
concatenated_parts.append(previous_part)

return concatenated_parts


def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Gemini format.
Make sure the "user" role and "model" role are interleaved.
Also, make sure the last item is from the "user" role.
"""
prev_role = None
rst = []
curr_parts = []
for i, message in enumerate(messages):
parts = oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"

if prev_role is None or role == prev_role:
curr_parts += parts
elif role != prev_role:
rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
curr_parts = parts
prev_role = role

# handle the last message
rst.append(Content(parts=concat_parts(curr_parts), role=role))

# The Gemini is restrict on order of roles, such that
# 1. The messages should be interleaved between user and model.
# 2. The last message must be from the user role.
# We add a dummy message "continue" if the last role is not the user.
if rst[-1].role != "user":
rst.append(Content(parts=oai_content_to_gemini_content("continue"), role="user"))
rst.append(Content(parts=curr_parts, role=role))

# The Gemini is restrict on order of roles, such that
# 1. The messages should be interleaved between user and model.
# 2. The last message must be from the user role.
# We add a dummy message "continue" if the last role is not the user.
if rst[-1].role != "user":
if self.use_vertexai:
rst.append(VertexAIContent(parts=self.oai_content_to_gemini_content("continue"), role="user"))
else:
rst.append(Content(parts=self.oai_content_to_gemini_content("continue"), role="user"))

return rst
return rst


def _to_pil(data: str) -> Image.Image:
Expand Down Expand Up @@ -336,3 +389,16 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
return base64.b64encode(content).decode("utf-8")
else:
return content


def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
if "1.5" in model_name or "gemini-experimental" in model_name:
# "gemini-1.5-pro-preview-0409"
# Cost is $7 per million input tokens and $21 per million output tokens
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6

if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)

# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graph": ["networkx", "matplotlib"],
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
"gemini": ["google-generativeai>=0.5,<1", "google-cloud-aiplatform", "google-auth", "pillow", "pydantic"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
"cosmosdb": ["azure-cosmos>=4.2.0"],
Expand Down
13 changes: 10 additions & 3 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ def gemini_client():
return GeminiClient(api_key="fake_api_key")


# Test initialization and configuration
# Test compute location initialization and configuration
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_initialization():
def test_compute_location_initialization():
with pytest.raises(AssertionError):
GeminiClient() # Should raise an AssertionError due to missing API key
GeminiClient(
api_key="fake_api_key", location="us-west1"
) # Should raise an AssertionError due to specifying API key and compute location


@pytest.fixture
def gemini_google_auth_default_client():
return GeminiClient()


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
Expand Down
Loading