diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
index fcf7e09c025d..5c06a4def0c9 100644
--- a/autogen/oai/gemini.py
+++ b/autogen/oai/gemini.py
@@ -5,8 +5,18 @@
 llm_config={
     "config_list": [{
         "api_type": "google",
-        "model": "models/gemini-pro",
-        "api_key": os.environ.get("GOOGLE_API_KEY")
+        "model": "gemini-pro",
+        "api_key": os.environ.get("GOOGLE_API_KEY"),
+        "safety_settings": [
+            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
+            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
+            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
+            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
+        ],
+        "top_p": 0.5,
+        "max_tokens": 2048,
+        "temperature": 1.0,
+        "top_k": 5
     }
 ]}
@@ -47,6 +57,17 @@ class GeminiClient:
     of AutoGen.
     """
 
+    # Maps each parameter name used by AutoGen (key) to the name Gemini expects (value)
+    PARAMS_MAPPING = {
+        "max_tokens": "max_output_tokens",
+        # "n": "candidate_count", # Gemini supports only `n=1`
+        "stop_sequences": "stop_sequences",
+        "temperature": "temperature",
+        "top_p": "top_p",
+        "top_k": "top_k",
+        "max_output_tokens": "max_output_tokens",
+    }
+
     def __init__(self, **kwargs):
         self.api_key = kwargs.get("api_key", None)
         if not self.api_key:
@@ -93,12 +114,15 @@ def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
         stream = params.get("stream", False)
         n_response = params.get("n", 1)
-        params.get("temperature", 0.5)
-        params.get("top_p", 1.0)
-        params.get("max_tokens", 4096)
+
+        generation_config = {
+            gemini_term: params[autogen_term]
+            for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
+            if autogen_term in params
+        }
+        safety_settings = params.get("safety_settings", {})
 
         if stream:
-            # warn user that streaming is not supported
             warnings.warn(
                 "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
                 UserWarning,
@@ -112,7 +136,9 @@ def create(self, params: Dict) -> ChatCompletion:
             gemini_messages = oai_messages_to_gemini_messages(messages)
 
             # we use chat model by default
-            model = genai.GenerativeModel(model_name)
+            model = genai.GenerativeModel(
+                model_name, generation_config=generation_config, safety_settings=safety_settings
+            )
             genai.configure(api_key=self.api_key)
             chat = model.start_chat(history=gemini_messages[:-1])
             max_retries = 5
@@ -142,7 +168,9 @@ def create(self, params: Dict) -> ChatCompletion:
         elif model_name == "gemini-pro-vision":
             # B. handle the vision model
             # Gemini's vision model does not support chat history yet
-            model = genai.GenerativeModel(model_name)
+            model = genai.GenerativeModel(
+                model_name, generation_config=generation_config, safety_settings=safety_settings
+            )
             genai.configure(api_key=self.api_key)
             # chat = model.start_chat(history=gemini_messages[:-1])
             # response = chat.send_message(gemini_messages[-1])
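
Note for reviewers: a minimal standalone sketch of what the new `PARAMS_MAPPING` comprehension produces, with a hypothetical `params` dict. Keys absent from the mapping (`model`, `api_key`, `messages`, ...) never reach `generation_config`, and `max_tokens` is renamed to Gemini's `max_output_tokens`:

```python
# Standalone sketch (not part of the patch) of the parameter translation.
# Sample values are hypothetical.
PARAMS_MAPPING = {
    "max_tokens": "max_output_tokens",
    "stop_sequences": "stop_sequences",
    "temperature": "temperature",
    "top_p": "top_p",
    "top_k": "top_k",
    "max_output_tokens": "max_output_tokens",
}

params = {
    "model": "gemini-pro",  # not in the mapping -> silently dropped
    "temperature": 1.0,
    "top_p": 0.5,
    "top_k": 5,
    "max_tokens": 2048,  # renamed to Gemini's max_output_tokens
}

generation_config = {
    gemini_term: params[autogen_term]
    for autogen_term, gemini_term in PARAMS_MAPPING.items()
    if autogen_term in params
}

print(generation_config)
# {'max_output_tokens': 2048, 'temperature': 1.0, 'top_p': 0.5, 'top_k': 5}
```

One consequence of listing both aliases in the mapping: if a caller passes both `max_tokens` and `max_output_tokens`, the comprehension writes the `max_output_tokens` key twice, and the alias iterated last in `PARAMS_MAPPING` (the native `max_output_tokens`) wins.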
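
For context, a hedged end-to-end sketch of the `google.generativeai` calls the patched chat path boils down to, assuming the package is installed and `GOOGLE_API_KEY` is set; the safety categories and thresholds mirror the docstring example above:

```python
# Sketch of the genai calls made by the patched client; not the client itself.
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

model = genai.GenerativeModel(
    "gemini-pro",
    generation_config={"max_output_tokens": 2048, "temperature": 1.0, "top_p": 0.5, "top_k": 5},
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
    ],
)

chat = model.start_chat(history=[])
response = chat.send_message("Hello")
print(response.text)
```

Both the chat (`gemini-pro`) and vision (`gemini-pro-vision`) branches now construct the model this way, so a single `generation_config`/`safety_settings` pair applies to either path.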