From a959deeac6322d4cf865ebfe5f83649ecff2c5d9 Mon Sep 17 00:00:00 2001 From: Zoltan Lux Date: Mon, 3 Jun 2024 23:00:00 +0200 Subject: [PATCH] Use Gemini without API key (#2805) * google default auth and svc keyfile for Gemini * [.Net] Release note for 0.0.14 (#2815) * update release note * update trigger * [.Net] Update website for AutoGen.SemanticKernel and AutoGen.Ollama (#2814) support vertex ai compute region * [CAP] User supplied threads for agents (#2812) * First pass: message loop in main thread * pypi version bump * Fix readme * Better example * Fixed docs * pre-commit fixes * refactoring, minor fixes, update gemini demo ipynb * add new deps again and reset line endings * Docstring for the init function. Use private methods * improve docstring --------- Co-authored-by: Xiaoyun Zhang Co-authored-by: Rajan Co-authored-by: Zoltan Lux --- autogen/oai/gemini.py | 282 +++++++++++------- setup.py | 2 +- test/oai/test_gemini.py | 13 +- .../non-openai-models/cloud-gemini.ipynb | 18 +- 4 files changed, 208 insertions(+), 107 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 5c06a4def0c9..60a2062bb89c 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -42,12 +42,16 @@ import google.generativeai as genai import requests +import vertexai from google.ai.generativelanguage import Content, Part from google.api_core.exceptions import InternalServerError from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from PIL import Image +from vertexai.generative_models import Content as VertexAIContent +from vertexai.generative_models import GenerativeModel +from vertexai.generative_models import Part as VertexAIPart class GeminiClient: @@ -68,14 +72,49 @@ class GeminiClient: "max_output_tokens": "max_output_tokens", } + def _initialize_vartexai(self, **params): + if "google_application_credentials" in params: + # Path to JSON Keyfile + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"] + vertexai_init_args = {} + if "project_id" in params: + vertexai_init_args["project"] = params["project_id"] + if "location" in params: + vertexai_init_args["location"] = params["location"] + if vertexai_init_args: + vertexai.init(**vertexai_init_args) + def __init__(self, **kwargs): + """Uses either either api_key for authentication from the LLM config + (specifying the GOOGLE_API_KEY environment variable also works), + or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified, + where project_id and location can also be passed as parameters. Service account key file can also be used. + If neither a service account key file, nor the api_key are passed, then the default credentials will be used, + which could be a personal account if the user is already authenticated in, like in Google Cloud Shell. + + Args: + api_key (str): The API key for using Gemini. + google_application_credentials (str): Path to the JSON service account key file of the service account. + Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable + can also be set instead of using this argument. + project_id (str): Google Cloud project id, which is only valid in case no API key is specified. + location (str): Compute region to be used, like 'us-west1'. + This parameter is only valid in case no API key is specified. + """ self.api_key = kwargs.get("api_key", None) if not self.api_key: self.api_key = os.getenv("GOOGLE_API_KEY") - - assert ( - self.api_key - ), "Please provide api_key in your config list entry for Gemini or set the GOOGLE_API_KEY env variable." + if self.api_key is None: + self.use_vertexai = True + self._initialize_vartexai(**kwargs) + else: + self.use_vertexai = False + else: + self.use_vertexai = False + if not self.use_vertexai: + assert ("project_id" not in kwargs) and ( + "location" not in kwargs + ), "Google Cloud project and compute location cannot be set when using an API Key!" def message_retrieval(self, response) -> List: """ @@ -102,6 +141,12 @@ def get_usage(response) -> Dict: } def create(self, params: Dict) -> ChatCompletion: + if self.use_vertexai: + self._initialize_vartexai(**params) + else: + assert ("project_id" not in params) and ( + "location" not in params + ), "Google Cloud project and compute location cannot be set when using an API Key!" model_name = params.get("model", "gemini-pro") if not model_name: raise ValueError( @@ -133,13 +178,17 @@ def create(self, params: Dict) -> ChatCompletion: if "vision" not in model_name: # A. create and call the chat model. - gemini_messages = oai_messages_to_gemini_messages(messages) - - # we use chat model by default - model = genai.GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings - ) - genai.configure(api_key=self.api_key) + gemini_messages = self._oai_messages_to_gemini_messages(messages) + if self.use_vertexai: + model = GenerativeModel( + model_name, generation_config=generation_config, safety_settings=safety_settings + ) + else: + # we use chat model by default + model = genai.GenerativeModel( + model_name, generation_config=generation_config, safety_settings=safety_settings + ) + genai.configure(api_key=self.api_key) chat = model.start_chat(history=gemini_messages[:-1]) max_retries = 5 for attempt in range(max_retries): @@ -167,14 +216,19 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens = model.count_tokens(ans).total_tokens elif model_name == "gemini-pro-vision": # B. handle the vision model + if self.use_vertexai: + model = GenerativeModel( + model_name, generation_config=generation_config, safety_settings=safety_settings + ) + else: + model = genai.GenerativeModel( + model_name, generation_config=generation_config, safety_settings=safety_settings + ) + genai.configure(api_key=self.api_key) # Gemini's vision model does not support chat history yet - model = genai.GenerativeModel( - model_name, generation_config=generation_config, safety_settings=safety_settings - ) - genai.configure(api_key=self.api_key) # chat = model.start_chat(history=gemini_messages[:-1]) # response = chat.send_message(gemini_messages[-1]) - user_message = oai_content_to_gemini_content(messages[-1]["content"]) + user_message = self._oai_content_to_gemini_content(messages[-1]["content"]) if len(messages) > 2: warnings.warn( "Warning: Gemini's vision model does not support chat history yet.", @@ -184,7 +238,10 @@ def create(self, params: Dict) -> ChatCompletion: response = model.generate_content(user_message, stream=stream) # ans = response.text - ans: str = response._result.candidates[0].content.parts[0].text + if self.use_vertexai: + ans: str = response.candidates[0].content.parts[0].text + else: + ans: str = response._result.candidates[0].content.parts[0].text prompt_tokens = model.count_tokens(user_message).total_tokens completion_tokens = model.count_tokens(ans).total_tokens @@ -209,99 +266,111 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai + def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: + """Convert content from OAI format to Gemini format""" + rst = [] + if isinstance(content, str): + if self.use_vertexai: + rst.append(VertexAIPart.from_text(content)) + else: + rst.append(Part(text=content)) + return rst + + assert isinstance(content, list) + + for msg in content: + if isinstance(msg, dict): + assert "type" in msg, f"Missing 'type' field in message: {msg}" + if msg["type"] == "text": + if self.use_vertexai: + rst.append(VertexAIPart.from_text(text=msg["text"])) + else: + rst.append(Part(text=msg["text"])) + elif msg["type"] == "image_url": + if self.use_vertexai: + img_url = msg["image_url"]["url"] + re.match(r"data:image/(?:png|jpeg);base64,", img_url) + img = get_image_data(img_url, use_b64=False) + # image/png works with jpeg as well + img_part = VertexAIPart.from_data(img, mime_type="image/png") + rst.append(img_part) + else: + b64_img = get_image_data(msg["image_url"]["url"]) + img = _to_pil(b64_img) + rst.append(img) + else: + raise ValueError(f"Unsupported message type: {msg['type']}") + else: + raise ValueError(f"Unsupported message type: {type(msg)}") + return rst -def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float: - if "1.5" in model_name or "gemini-experimental" in model_name: - # "gemini-1.5-pro-preview-0409" - # Cost is $7 per million input tokens and $21 per million output tokens - return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 - - if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name: - warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning) - - # Cost is $0.5 per million input tokens and $1.5 per million output tokens - return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6 - + def _concat_parts(self, parts: List[Part]) -> List: + """Concatenate parts with the same type. + If two adjacent parts both have the "text" attribute, then it will be joined into one part. + """ + if not parts: + return [] -def oai_content_to_gemini_content(content: Union[str, List]) -> List: - """Convert content from OAI format to Gemini format""" - rst = [] - if isinstance(content, str): - rst.append(Part(text=content)) - return rst + concatenated_parts = [] + previous_part = parts[0] - assert isinstance(content, list) - - for msg in content: - if isinstance(msg, dict): - assert "type" in msg, f"Missing 'type' field in message: {msg}" - if msg["type"] == "text": - rst.append(Part(text=msg["text"])) - elif msg["type"] == "image_url": - b64_img = get_image_data(msg["image_url"]["url"]) - img = _to_pil(b64_img) - rst.append(img) + for current_part in parts[1:]: + if previous_part.text != "": + if self.use_vertexai: + previous_part = VertexAIPart.from_text(previous_part.text + current_part.text) + else: + previous_part.text += current_part.text else: - raise ValueError(f"Unsupported message type: {msg['type']}") - else: - raise ValueError(f"Unsupported message type: {type(msg)}") - return rst + concatenated_parts.append(previous_part) + previous_part = current_part + if previous_part.text == "": + if self.use_vertexai: + previous_part = VertexAIPart.from_text("empty") + else: + previous_part.text = "empty" # Empty content is not allowed. + concatenated_parts.append(previous_part) -def concat_parts(parts: List[Part]) -> List: - """Concatenate parts with the same type. - If two adjacent parts both have the "text" attribute, then it will be joined into one part. - """ - if not parts: - return [] + return concatenated_parts - concatenated_parts = [] - previous_part = parts[0] + def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: + """Convert messages from OAI format to Gemini format. + Make sure the "user" role and "model" role are interleaved. + Also, make sure the last item is from the "user" role. + """ + prev_role = None + rst = [] + curr_parts = [] + for i, message in enumerate(messages): + parts = self._oai_content_to_gemini_content(message["content"]) + role = "user" if message["role"] in ["user", "system"] else "model" + + if prev_role is None or role == prev_role: + curr_parts += parts + elif role != prev_role: + if self.use_vertexai: + rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role)) + else: + rst.append(Content(parts=curr_parts, role=prev_role)) + prev_role = role - for current_part in parts[1:]: - if previous_part.text != "": - previous_part.text += current_part.text + # handle the last message + if self.use_vertexai: + rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role)) else: - concatenated_parts.append(previous_part) - previous_part = current_part - - if previous_part.text == "": - previous_part.text = "empty" # Empty content is not allowed. - concatenated_parts.append(previous_part) - - return concatenated_parts - - -def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: - """Convert messages from OAI format to Gemini format. - Make sure the "user" role and "model" role are interleaved. - Also, make sure the last item is from the "user" role. - """ - prev_role = None - rst = [] - curr_parts = [] - for i, message in enumerate(messages): - parts = oai_content_to_gemini_content(message["content"]) - role = "user" if message["role"] in ["user", "system"] else "model" - - if prev_role is None or role == prev_role: - curr_parts += parts - elif role != prev_role: - rst.append(Content(parts=concat_parts(curr_parts), role=prev_role)) - curr_parts = parts - prev_role = role - - # handle the last message - rst.append(Content(parts=concat_parts(curr_parts), role=role)) - - # The Gemini is restrict on order of roles, such that - # 1. The messages should be interleaved between user and model. - # 2. The last message must be from the user role. - # We add a dummy message "continue" if the last role is not the user. - if rst[-1].role != "user": - rst.append(Content(parts=oai_content_to_gemini_content("continue"), role="user")) + rst.append(Content(parts=curr_parts, role=role)) + + # The Gemini is restrict on order of roles, such that + # 1. The messages should be interleaved between user and model. + # 2. The last message must be from the user role. + # We add a dummy message "continue" if the last role is not the user. + if rst[-1].role != "user": + if self.use_vertexai: + rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user")) + else: + rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user")) - return rst + return rst def _to_pil(data: str) -> Image.Image: @@ -336,3 +405,16 @@ def get_image_data(image_file: str, use_b64=True) -> bytes: return base64.b64encode(content).decode("utf-8") else: return content + + +def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float: + if "1.5" in model_name or "gemini-experimental" in model_name: + # "gemini-1.5-pro-preview-0409" + # Cost is $7 per million input tokens and $21 per million output tokens + return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + + if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name: + warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning) + + # Cost is $0.5 per million input tokens and $1.5 per million output tokens + return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6 diff --git a/setup.py b/setup.py index b3a868327507..c24b810e3174 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ "teachable": ["chromadb"], "lmm": ["replicate", "pillow"], "graph": ["networkx", "matplotlib"], - "gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"], + "gemini": ["google-generativeai>=0.5,<1", "google-cloud-aiplatform", "google-auth", "pillow", "pydantic"], "websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"], "redis": ["redis"], "cosmosdb": ["azure-cosmos>=4.2.0"], diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index 7161d605fb6d..4f77d288789a 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -33,11 +33,18 @@ def gemini_client(): return GeminiClient(api_key="fake_api_key") -# Test initialization and configuration +# Test compute location initialization and configuration @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") -def test_initialization(): +def test_compute_location_initialization(): with pytest.raises(AssertionError): - GeminiClient() # Should raise an AssertionError due to missing API key + GeminiClient( + api_key="fake_api_key", location="us-west1" + ) # Should raise an AssertionError due to specifying API key and compute location + + +@pytest.fixture +def gemini_google_auth_default_client(): + return GeminiClient() @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") diff --git a/website/docs/topics/non-openai-models/cloud-gemini.ipynb b/website/docs/topics/non-openai-models/cloud-gemini.ipynb index a794b8552e5f..da773e0d4472 100644 --- a/website/docs/topics/non-openai-models/cloud-gemini.ipynb +++ b/website/docs/topics/non-openai-models/cloud-gemini.ipynb @@ -24,11 +24,13 @@ "\n", "## Features\n", "\n", - "There's no need to handle OpenAI or Google's GenAI packages separately; AutoGen manages all of these for you. You can easily create different agents with various backend LLMs using the assistant agent. All models and agents are readily accessible at your fingertips.\n", + "There's no need to handle OpenAI or Google's GenAI packages separately; AutoGen manages all of these for you. You can easily create different agents with various backend LLMs using the assistant agent. All models and agents are readily accessible at your fingertips. \n", + " \n", "\n", "## Main Distinctions\n", "\n", - "- Currently, Gemini does not include a \"system_message\" field. However, you can incorporate this instruction into the first message of your interaction." + "- Currently, Gemini does not include a \"system_message\" field. However, you can incorporate this instruction into the first message of your interaction.\n", + "- If no API key is specified for Gemini, then authentication will happen using the default google auth mechanism for Google Cloud. Service accounts are also supported, where the JSON key file has to be provided." ] }, { @@ -57,6 +59,16 @@ " \"api_type\": \"google\"\n", " },\n", " {\n", + " \"model\": \"gemini-1.5-pro-001\",\n", + " \"api_type\": \"google\"\n", + " },\n", + " {\n", + " \"model\": \"gemini-1.5-pro\",\n", + " \"project\": \"your-awesome-google-cloud-project-id\",\n", + " \"location\": \"us-west1\",\n", + " \"google_application_credentials\": \"your-google-service-account-key.json\"\n", + " },\n", + " {\n", " \"model\": \"gemini-pro-vision\",\n", " \"api_key\": \"your Google's GenAI Key goes here\",\n", " \"api_type\": \"google\"\n", @@ -110,7 +122,7 @@ "config_list_gemini = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", " filter_dict={\n", - " \"model\": [\"gemini-pro\"],\n", + " \"model\": [\"gemini-pro\", \"gemini-1.5-pro\", \"gemini-1.5-pro-001\"],\n", " },\n", ")\n", "\n",