Skip to content

Commit

Permalink
added Gemini safety setting and Gemini generation config (#2429)
Browse files Browse the repository at this point in the history
* added Gemini safety setting and Gemini generation config

* define params_mapping as a constant as a class variable

* fixed formatting issues

---------

Co-authored-by: nikolay tolstov <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
4 people authored and victordibia committed Jul 30, 2024
1 parent e92daa1 commit 51e1ea8
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@
llm_config={
"config_list": [{
"api_type": "google",
"model": "models/gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY")
"model": "gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
],
"top_p":0.5,
"max_tokens": 2048,
"temperature": 1.0,
"top_k": 5
}
]}
Expand Down Expand Up @@ -47,6 +57,17 @@ class GeminiClient:
of AutoGen.
"""

# Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
PARAMS_MAPPING = {
"max_tokens": "max_output_tokens",
# "n": "candidate_count", # Gemini supports only `n=1`
"stop_sequences": "stop_sequences",
"temperature": "temperature",
"top_p": "top_p",
"top_k": "top_k",
"max_output_tokens": "max_output_tokens",
}

def __init__(self, **kwargs):
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
Expand Down Expand Up @@ -93,12 +114,15 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
params.get("temperature", 0.5)
params.get("top_p", 1.0)
params.get("max_tokens", 4096)

generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})

if stream:
# warn user that streaming is not supported
warnings.warn(
"Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
UserWarning,
Expand All @@ -112,7 +136,9 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = oai_messages_to_gemini_messages(messages)

# we use chat model by default
model = genai.GenerativeModel(model_name)
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
Expand Down Expand Up @@ -142,7 +168,9 @@ def create(self, params: Dict) -> ChatCompletion:
elif model_name == "gemini-pro-vision":
# B. handle the vision model
# Gemini's vision model does not support chat history yet
model = genai.GenerativeModel(model_name)
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
Expand Down

0 comments on commit 51e1ea8

Please sign in to comment.