Skip to content

Commit f1551c0

Browse files
committed
remove vertexai safety settings
1 parent e265cc2 commit f1551c0

File tree

2 files changed

+2
-63
lines changed

2 files changed

+2
-63
lines changed

autogen/oai/gemini.py

+2-28
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from __future__ import annotations
3333

3434
import base64
35-
import logging
3635
import os
3736
import random
3837
import re
@@ -44,7 +43,6 @@
4443
import google.generativeai as genai
4544
import requests
4645
import vertexai
47-
from flaml.automl.logger import logger_formatter
4846
from google.ai.generativelanguage import Content, Part
4947
from google.api_core.exceptions import InternalServerError
5048
from openai.types.chat import ChatCompletion
@@ -53,12 +51,7 @@
5351
from PIL import Image
5452
from vertexai.generative_models import Content as VertexAIContent
5553
from vertexai.generative_models import GenerativeModel
56-
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
57-
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
5854
from vertexai.generative_models import Part as VertexAIPart
59-
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
60-
61-
logger = logging.getLogger(__name__)
6255

6356

6457
class GeminiClient:
@@ -173,7 +166,6 @@ def create(self, params: Dict) -> ChatCompletion:
173166
if autogen_term in params
174167
}
175168
safety_settings = params.get("safety_settings", {})
176-
vertexai_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
177169

178170
if stream:
179171
warnings.warn(
@@ -189,7 +181,7 @@ def create(self, params: Dict) -> ChatCompletion:
189181
gemini_messages = self._oai_messages_to_gemini_messages(messages)
190182
if self.use_vertexai:
191183
model = GenerativeModel(
192-
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
184+
model_name, generation_config=generation_config, safety_settings=safety_settings
193185
)
194186
else:
195187
# we use chat model by default
@@ -226,7 +218,7 @@ def create(self, params: Dict) -> ChatCompletion:
226218
# B. handle the vision model
227219
if self.use_vertexai:
228220
model = GenerativeModel(
229-
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
221+
model_name, generation_config=generation_config, safety_settings=safety_settings
230222
)
231223
else:
232224
model = genai.GenerativeModel(
@@ -380,24 +372,6 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
380372

381373
return rst
382374

383-
@staticmethod
384-
def _to_vertexai_safety_settings(safety_settings):
385-
vertexai_safety_settings = []
386-
for safety_setting in safety_settings:
387-
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
388-
invalid_category = safety_setting["category"]
389-
logger.error(f"Safety setting category {invalid_category} is invalid")
390-
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
391-
invalid_threshold = safety_setting["threshold"]
392-
logger.error(f"Safety threshold {invalid_threshold} is invalid")
393-
else:
394-
vertexai_safety_setting = VertexAISafetySetting(
395-
category=safety_setting["category"],
396-
threshold=safety_setting["threshold"],
397-
)
398-
vertexai_safety_settings.append(vertexai_safety_setting)
399-
return vertexai_safety_settings
400-
401375

402376
def _to_pil(data: str) -> Image.Image:
403377
"""

test/oai/test_gemini.py

-35
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44

55
try:
66
from google.api_core.exceptions import InternalServerError
7-
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
8-
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
9-
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
107

118
from autogen.oai.gemini import GeminiClient
129

@@ -55,38 +52,6 @@ def test_valid_initialization(gemini_client):
5552
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
5653

5754

58-
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
59-
def test_vertexai_safety_setting_conversion(gemini_client):
60-
safety_settings = [
61-
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
62-
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
63-
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
64-
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
65-
]
66-
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
67-
harm_categories = [
68-
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
69-
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
70-
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
71-
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
72-
]
73-
expected_safety_settings = [
74-
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
75-
for category in harm_categories
76-
]
77-
78-
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
79-
for i, expected_setting in enumerate(expected_safety_settings):
80-
converted_setting = converted_safety_settings[i]
81-
yield expected_setting.to_dict() == converted_setting.to_dict()
82-
83-
assert len(converted_safety_settings) == len(
84-
expected_safety_settings
85-
), "The length of the safety settings is incorrect"
86-
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
87-
assert all(settings_comparison), "Converted safety settings are incorrect"
88-
89-
9055
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
9156
def test_gemini_message_handling(gemini_client):
9257
messages = [

0 commit comments

Comments
 (0)