32
32
from __future__ import annotations
33
33
34
34
import base64
35
- import logging
36
35
import os
37
36
import random
38
37
import re
44
43
import google .generativeai as genai
45
44
import requests
46
45
import vertexai
47
- from flaml .automl .logger import logger_formatter
48
46
from google .ai .generativelanguage import Content , Part
49
47
from google .api_core .exceptions import InternalServerError
50
48
from openai .types .chat import ChatCompletion
53
51
from PIL import Image
54
52
from vertexai .generative_models import Content as VertexAIContent
55
53
from vertexai .generative_models import GenerativeModel
56
- from vertexai .generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
57
- from vertexai .generative_models import HarmCategory as VertexAIHarmCategory
58
54
from vertexai .generative_models import Part as VertexAIPart
59
- from vertexai .generative_models import SafetySetting as VertexAISafetySetting
60
-
61
- logger = logging .getLogger (__name__ )
62
55
63
56
64
57
class GeminiClient :
@@ -173,7 +166,6 @@ def create(self, params: Dict) -> ChatCompletion:
173
166
if autogen_term in params
174
167
}
175
168
safety_settings = params .get ("safety_settings" , {})
176
- vertexai_safety_settings = GeminiClient ._to_vertexai_safety_settings (safety_settings )
177
169
178
170
if stream :
179
171
warnings .warn (
@@ -189,7 +181,7 @@ def create(self, params: Dict) -> ChatCompletion:
189
181
gemini_messages = self ._oai_messages_to_gemini_messages (messages )
190
182
if self .use_vertexai :
191
183
model = GenerativeModel (
192
- model_name , generation_config = generation_config , safety_settings = vertexai_safety_settings
184
+ model_name , generation_config = generation_config , safety_settings = safety_settings
193
185
)
194
186
else :
195
187
# we use chat model by default
@@ -226,7 +218,7 @@ def create(self, params: Dict) -> ChatCompletion:
226
218
# B. handle the vision model
227
219
if self .use_vertexai :
228
220
model = GenerativeModel (
229
- model_name , generation_config = generation_config , safety_settings = vertexai_safety_settings
221
+ model_name , generation_config = generation_config , safety_settings = safety_settings
230
222
)
231
223
else :
232
224
model = genai .GenerativeModel (
@@ -380,24 +372,6 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
380
372
381
373
return rst
382
374
383
- @staticmethod
384
- def _to_vertexai_safety_settings (safety_settings ):
385
- vertexai_safety_settings = []
386
- for safety_setting in safety_settings :
387
- if safety_setting ["category" ] not in VertexAIHarmCategory .__members__ :
388
- invalid_category = safety_setting ["category" ]
389
- logger .error (f"Safety setting category { invalid_category } is invalid" )
390
- elif safety_setting ["threshold" ] not in VertexAIHarmBlockThreshold .__members__ :
391
- invalid_threshold = safety_setting ["threshold" ]
392
- logger .error (f"Safety threshold { invalid_threshold } is invalid" )
393
- else :
394
- vertexai_safety_setting = VertexAISafetySetting (
395
- category = safety_setting ["category" ],
396
- threshold = safety_setting ["threshold" ],
397
- )
398
- vertexai_safety_settings .append (vertexai_safety_setting )
399
- return vertexai_safety_settings
400
-
401
375
402
376
def _to_pil (data : str ) -> Image .Image :
403
377
"""
0 commit comments