6
6
"config_list": [{
7
7
"api_type": "google",
8
8
"model": "gemini-pro",
9
- "api_key": os.environ.get("GOOGLE_API_KEY "),
9
+ "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY "),
10
10
"safety_settings": [
11
11
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
12
12
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
32
32
from __future__ import annotations
33
33
34
34
import base64
35
+ import logging
35
36
import os
36
37
import random
37
38
import re
45
46
import vertexai
46
47
from google .ai .generativelanguage import Content , Part
47
48
from google .api_core .exceptions import InternalServerError
49
+ from google .auth .credentials import Credentials
48
50
from openai .types .chat import ChatCompletion
49
51
from openai .types .chat .chat_completion import ChatCompletionMessage , Choice
50
52
from openai .types .completion_usage import CompletionUsage
51
53
from PIL import Image
52
54
from vertexai .generative_models import Content as VertexAIContent
53
55
from vertexai .generative_models import GenerativeModel
56
+ from vertexai .generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
57
+ from vertexai .generative_models import HarmCategory as VertexAIHarmCategory
54
58
from vertexai .generative_models import Part as VertexAIPart
59
+ from vertexai .generative_models import SafetySetting as VertexAISafetySetting
60
+
61
+ logger = logging .getLogger (__name__ )
55
62
56
63
57
64
class GeminiClient :
@@ -81,29 +88,36 @@ def _initialize_vertexai(self, **params):
81
88
vertexai_init_args ["project" ] = params ["project_id" ]
82
89
if "location" in params :
83
90
vertexai_init_args ["location" ] = params ["location" ]
91
+ if "credentials" in params :
92
+ assert isinstance (
93
+ params ["credentials" ], Credentials
94
+ ), "Object type google.auth.credentials.Credentials is expected!"
95
+ vertexai_init_args ["credentials" ] = params ["credentials" ]
84
96
if vertexai_init_args :
85
97
vertexai .init (** vertexai_init_args )
86
98
87
99
def __init__ (self , ** kwargs ):
88
100
"""Uses either api_key for authentication from the LLM config
89
- (specifying the GOOGLE_API_KEY environment variable also works),
101
+ (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
90
102
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
91
- where project_id and location can also be passed as parameters. Service account key file can also be used.
92
- If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
93
- which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
103
+ where project_id and location can also be passed as parameters. A previously created credentials object can be provided,
104
+ or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
105
+ then the default credentials will be used, which could be a personal account if the user is already authenticated in,
106
+ like in Google Cloud Shell.
94
107
95
108
Args:
96
109
api_key (str): The API key for using Gemini.
110
+ credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
97
111
google_application_credentials (str): Path to the JSON service account key file of the service account.
98
- Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
99
- can also be set instead of using this argument.
112
+ Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
113
+ can also be set instead of using this argument.
100
114
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
101
115
location (str): Compute region to be used, like 'us-west1'.
102
- This parameter is only valid in case no API key is specified.
116
+ This parameter is only valid in case no API key is specified.
103
117
"""
104
118
self .api_key = kwargs .get ("api_key" , None )
105
119
if not self .api_key :
106
- self .api_key = os .getenv ("GOOGLE_API_KEY " )
120
+ self .api_key = os .getenv ("GOOGLE_GEMINI_API_KEY " )
107
121
if self .api_key is None :
108
122
self .use_vertexai = True
109
123
self ._initialize_vertexai (** kwargs )
@@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion:
159
173
messages = params .get ("messages" , [])
160
174
stream = params .get ("stream" , False )
161
175
n_response = params .get ("n" , 1 )
176
+ system_instruction = params .get ("system_instruction" , None )
177
+ response_validation = params .get ("response_validation" , True )
162
178
163
179
generation_config = {
164
180
gemini_term : params [autogen_term ]
165
181
for autogen_term , gemini_term in self .PARAMS_MAPPING .items ()
166
182
if autogen_term in params
167
183
}
168
- safety_settings = params .get ("safety_settings" , {})
184
+ if self .use_vertexai :
185
+ safety_settings = GeminiClient ._to_vertexai_safety_settings (params .get ("safety_settings" , {}))
186
+ else :
187
+ safety_settings = params .get ("safety_settings" , {})
169
188
170
189
if stream :
171
190
warnings .warn (
@@ -181,20 +200,29 @@ def create(self, params: Dict) -> ChatCompletion:
181
200
gemini_messages = self ._oai_messages_to_gemini_messages (messages )
182
201
if self .use_vertexai :
183
202
model = GenerativeModel (
184
- model_name , generation_config = generation_config , safety_settings = safety_settings
203
+ model_name ,
204
+ generation_config = generation_config ,
205
+ safety_settings = safety_settings ,
206
+ system_instruction = system_instruction ,
185
207
)
208
+ chat = model .start_chat (history = gemini_messages [:- 1 ], response_validation = response_validation )
186
209
else :
187
210
# we use chat model by default
188
211
model = genai .GenerativeModel (
189
- model_name , generation_config = generation_config , safety_settings = safety_settings
212
+ model_name ,
213
+ generation_config = generation_config ,
214
+ safety_settings = safety_settings ,
215
+ system_instruction = system_instruction ,
190
216
)
191
217
genai .configure (api_key = self .api_key )
192
- chat = model .start_chat (history = gemini_messages [:- 1 ])
218
+ chat = model .start_chat (history = gemini_messages [:- 1 ])
193
219
max_retries = 5
194
220
for attempt in range (max_retries ):
195
221
ans = None
196
222
try :
197
- response = chat .send_message (gemini_messages [- 1 ], stream = stream )
223
+ response = chat .send_message (
224
+ gemini_messages [- 1 ].parts , stream = stream , safety_settings = safety_settings
225
+ )
198
226
except InternalServerError :
199
227
delay = 5 * (2 ** attempt )
200
228
warnings .warn (
@@ -218,16 +246,22 @@ def create(self, params: Dict) -> ChatCompletion:
218
246
# B. handle the vision model
219
247
if self .use_vertexai :
220
248
model = GenerativeModel (
221
- model_name , generation_config = generation_config , safety_settings = safety_settings
249
+ model_name ,
250
+ generation_config = generation_config ,
251
+ safety_settings = safety_settings ,
252
+ system_instruction = system_instruction ,
222
253
)
223
254
else :
224
255
model = genai .GenerativeModel (
225
- model_name , generation_config = generation_config , safety_settings = safety_settings
256
+ model_name ,
257
+ generation_config = generation_config ,
258
+ safety_settings = safety_settings ,
259
+ system_instruction = system_instruction ,
226
260
)
227
261
genai .configure (api_key = self .api_key )
228
262
# Gemini's vision model does not support chat history yet
229
263
# chat = model.start_chat(history=gemini_messages[:-1])
230
- # response = chat.send_message(gemini_messages[-1])
264
+ # response = chat.send_message(gemini_messages[-1].parts )
231
265
user_message = self ._oai_content_to_gemini_content (messages [- 1 ]["content" ])
232
266
if len (messages ) > 2 :
233
267
warnings .warn (
@@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
270
304
"""Convert content from OAI format to Gemini format"""
271
305
rst = []
272
306
if isinstance (content , str ):
307
+ if content == "" :
308
+ content = "empty" # Empty content is not allowed.
273
309
if self .use_vertexai :
274
310
rst .append (VertexAIPart .from_text (content ))
275
311
else :
@@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
372
408
373
409
return rst
374
410
411
@staticmethod
def _to_vertexai_safety_settings(safety_settings):
    """Convert plain-dict safety settings to VertexAI ``SafetySetting`` objects.

    Safety settings may arrive as a list of ``{"category": ..., "threshold": ...}``
    dicts, like when specifying them in the OAI_CONFIG_LIST. Such a list is
    converted entry by entry; any other input is handed back untouched.

    Args:
        safety_settings: Either a list made up exclusively of dicts, or any
            other value (e.g. an empty dict or a list that already holds
            VertexAI ``SafetySetting`` objects).

    Returns:
        A new list of VertexAI ``SafetySetting`` objects built from the valid
        dict entries when the input is a list of dicts; otherwise the input
        object itself, unchanged.

    Notes:
        An entry whose category or threshold is missing, is not a string, or
        is not a member name of the corresponding VertexAI enum is logged as
        an error and skipped rather than raising.
    """
    is_list_of_dicts = isinstance(safety_settings, list) and all(
        isinstance(entry, dict) for entry in safety_settings
    )
    if not is_list_of_dicts:
        # Not in the plain-dict format, so there is nothing to convert.
        return safety_settings

    known_categories = VertexAIHarmCategory.__members__
    known_thresholds = VertexAIHarmBlockThreshold.__members__

    converted = []
    for entry in safety_settings:
        # .get() keeps a missing key on the same log-and-skip path as an invalid value.
        category = entry.get("category")
        threshold = entry.get("threshold")

        # The isinstance guard avoids a TypeError when an unhashable value is
        # tested for membership in the enum's name mapping.
        category_ok = isinstance(category, str) and category in known_categories
        threshold_ok = isinstance(threshold, str) and threshold in known_thresholds

        # Report every problem of an entry, not only the first one found.
        if not category_ok:
            logger.error("Safety setting category %s is invalid", category)
        if not threshold_ok:
            logger.error("Safety threshold %s is invalid", threshold)

        if category_ok and threshold_ok:
            converted.append(VertexAISafetySetting(category=category, threshold=threshold))

    return converted
439
+
375
440
376
441
def _to_pil (data : str ) -> Image .Image :
377
442
"""
0 commit comments