Skip to content

Commit f24696d

Browse files
luxzoli authored and victordibia committed
Enhance vertexai integration (#3086)
* switch to officially supported Vertex AI message sending + safety setting conversion for vertexai
* add system instructions
* switch to officially supported Vertex AI message sending + safety setting conversion for vertexai
* fix bug in safety settings conversion
* add missing system instructions
* add safety settings to send message
* add support for credentials objects
* add type checking
* change project_id to project arg
* add more tests
* fix mock creation in test
* extend docstring
* fix errors with gemini message format in chats
* add option for vertexai response validation setting & improve docstring
* re-adding empty message handling
* add more tests
* extend and improve gemini vertexai jupyter notebook
* rename project arg to project_id and GOOGLE_API_KEY env var to GOOGLE_GEMINI_API_KEY
* adjust docstring formatting
1 parent 924420a commit f24696d

File tree

9 files changed

+518
-104
lines changed

9 files changed

+518
-104
lines changed

autogen/agentchat/chat.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
107107
return chat_order
108108

109109

110+
def _post_process_carryover_item(carryover_item):
111+
if isinstance(carryover_item, str):
112+
return carryover_item
113+
elif isinstance(carryover_item, dict) and "content" in carryover_item:
114+
return str(carryover_item["content"])
115+
else:
116+
return str(carryover_item)
117+
118+
110119
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
111120
iostream = IOStream.get_default()
112121

@@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
116125
UserWarning,
117126
)
118127
print_carryover = (
119-
("\n").join([t for t in chat_info["carryover"]])
128+
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
120129
if isinstance(chat_info["carryover"], list)
121130
else chat_info["carryover"]
122131
)

autogen/agentchat/conversable_agent.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from openai import BadRequestError
1313

14+
from autogen.agentchat.chat import _post_process_carryover_item
1415
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
1516

1617
from .._pydantic import model_dump
@@ -2364,7 +2365,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
23642365
if isinstance(kwargs["carryover"], str):
23652366
content += "\nContext: \n" + kwargs["carryover"]
23662367
elif isinstance(kwargs["carryover"], list):
2367-
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
2368+
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
23682369
else:
23692370
raise InvalidCarryOverType(
23702371
"Carryover should be a string or a list of strings. Not adding carryover to the message."

autogen/oai/gemini.py

+82-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"config_list": [{
77
"api_type": "google",
88
"model": "gemini-pro",
9-
"api_key": os.environ.get("GOOGLE_API_KEY"),
9+
"api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
1010
"safety_settings": [
1111
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
1212
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
@@ -32,6 +32,7 @@
3232
from __future__ import annotations
3333

3434
import base64
35+
import logging
3536
import os
3637
import random
3738
import re
@@ -45,13 +46,19 @@
4546
import vertexai
4647
from google.ai.generativelanguage import Content, Part
4748
from google.api_core.exceptions import InternalServerError
49+
from google.auth.credentials import Credentials
4850
from openai.types.chat import ChatCompletion
4951
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
5052
from openai.types.completion_usage import CompletionUsage
5153
from PIL import Image
5254
from vertexai.generative_models import Content as VertexAIContent
5355
from vertexai.generative_models import GenerativeModel
56+
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
57+
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
5458
from vertexai.generative_models import Part as VertexAIPart
59+
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
60+
61+
logger = logging.getLogger(__name__)
5562

5663

5764
class GeminiClient:
@@ -81,29 +88,36 @@ def _initialize_vertexai(self, **params):
8188
vertexai_init_args["project"] = params["project_id"]
8289
if "location" in params:
8390
vertexai_init_args["location"] = params["location"]
91+
if "credentials" in params:
92+
assert isinstance(
93+
params["credentials"], Credentials
94+
), "Object type google.auth.credentials.Credentials is expected!"
95+
vertexai_init_args["credentials"] = params["credentials"]
8496
if vertexai_init_args:
8597
vertexai.init(**vertexai_init_args)
8698

8799
def __init__(self, **kwargs):
88100
"""Uses either either api_key for authentication from the LLM config
89-
(specifying the GOOGLE_API_KEY environment variable also works),
101+
(specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
90102
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
91-
where project_id and location can also be passed as parameters. Service account key file can also be used.
92-
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
93-
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
103+
where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
104+
or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
105+
then the default credentials will be used, which could be a personal account if the user is already authenticated in,
106+
like in Google Cloud Shell.
94107
95108
Args:
96109
api_key (str): The API key for using Gemini.
110+
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
97111
google_application_credentials (str): Path to the JSON service account key file of the service account.
98-
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
99-
can also be set instead of using this argument.
112+
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
113+
can also be set instead of using this argument.
100114
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
101115
location (str): Compute region to be used, like 'us-west1'.
102-
This parameter is only valid in case no API key is specified.
116+
This parameter is only valid in case no API key is specified.
103117
"""
104118
self.api_key = kwargs.get("api_key", None)
105119
if not self.api_key:
106-
self.api_key = os.getenv("GOOGLE_API_KEY")
120+
self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
107121
if self.api_key is None:
108122
self.use_vertexai = True
109123
self._initialize_vertexai(**kwargs)
@@ -159,13 +173,18 @@ def create(self, params: Dict) -> ChatCompletion:
159173
messages = params.get("messages", [])
160174
stream = params.get("stream", False)
161175
n_response = params.get("n", 1)
176+
system_instruction = params.get("system_instruction", None)
177+
response_validation = params.get("response_validation", True)
162178

163179
generation_config = {
164180
gemini_term: params[autogen_term]
165181
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
166182
if autogen_term in params
167183
}
168-
safety_settings = params.get("safety_settings", {})
184+
if self.use_vertexai:
185+
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
186+
else:
187+
safety_settings = params.get("safety_settings", {})
169188

170189
if stream:
171190
warnings.warn(
@@ -181,20 +200,29 @@ def create(self, params: Dict) -> ChatCompletion:
181200
gemini_messages = self._oai_messages_to_gemini_messages(messages)
182201
if self.use_vertexai:
183202
model = GenerativeModel(
184-
model_name, generation_config=generation_config, safety_settings=safety_settings
203+
model_name,
204+
generation_config=generation_config,
205+
safety_settings=safety_settings,
206+
system_instruction=system_instruction,
185207
)
208+
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
186209
else:
187210
# we use chat model by default
188211
model = genai.GenerativeModel(
189-
model_name, generation_config=generation_config, safety_settings=safety_settings
212+
model_name,
213+
generation_config=generation_config,
214+
safety_settings=safety_settings,
215+
system_instruction=system_instruction,
190216
)
191217
genai.configure(api_key=self.api_key)
192-
chat = model.start_chat(history=gemini_messages[:-1])
218+
chat = model.start_chat(history=gemini_messages[:-1])
193219
max_retries = 5
194220
for attempt in range(max_retries):
195221
ans = None
196222
try:
197-
response = chat.send_message(gemini_messages[-1], stream=stream)
223+
response = chat.send_message(
224+
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
225+
)
198226
except InternalServerError:
199227
delay = 5 * (2**attempt)
200228
warnings.warn(
@@ -218,16 +246,22 @@ def create(self, params: Dict) -> ChatCompletion:
218246
# B. handle the vision model
219247
if self.use_vertexai:
220248
model = GenerativeModel(
221-
model_name, generation_config=generation_config, safety_settings=safety_settings
249+
model_name,
250+
generation_config=generation_config,
251+
safety_settings=safety_settings,
252+
system_instruction=system_instruction,
222253
)
223254
else:
224255
model = genai.GenerativeModel(
225-
model_name, generation_config=generation_config, safety_settings=safety_settings
256+
model_name,
257+
generation_config=generation_config,
258+
safety_settings=safety_settings,
259+
system_instruction=system_instruction,
226260
)
227261
genai.configure(api_key=self.api_key)
228262
# Gemini's vision model does not support chat history yet
229263
# chat = model.start_chat(history=gemini_messages[:-1])
230-
# response = chat.send_message(gemini_messages[-1])
264+
# response = chat.send_message(gemini_messages[-1].parts)
231265
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
232266
if len(messages) > 2:
233267
warnings.warn(
@@ -270,6 +304,8 @@ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
270304
"""Convert content from OAI format to Gemini format"""
271305
rst = []
272306
if isinstance(content, str):
307+
if content == "":
308+
content = "empty" # Empty content is not allowed.
273309
if self.use_vertexai:
274310
rst.append(VertexAIPart.from_text(content))
275311
else:
@@ -372,6 +408,35 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
372408

373409
return rst
374410

411+
@staticmethod
def _to_vertexai_safety_settings(safety_settings):
    """Convert dict-style safety settings to VertexAI ``SafetySetting`` objects.

    This is needed when the settings are specified as plain dicts, like in
    the OAI_CONFIG_LIST.

    Args:
        safety_settings: Either a list of dicts with ``"category"`` and
            ``"threshold"`` string entries, or any other value.

    Returns:
        A list of ``VertexAISafetySetting`` objects when the input is a list
        consisting only of dicts. Entries whose category or threshold is
        missing or is not a member name of the corresponding VertexAI enum
        are logged as errors and left out. Any other input is returned
        unchanged.
    """
    needs_conversion = isinstance(safety_settings, list) and all(
        isinstance(setting, dict) and not isinstance(setting, VertexAISafetySetting)
        for setting in safety_settings
    )
    if not needs_conversion:
        return safety_settings

    vertexai_safety_settings = []
    for setting in safety_settings:
        # ``.get`` makes a missing key behave like any other invalid value:
        # it is reported and skipped instead of raising ``KeyError``.
        category = setting.get("category")
        threshold = setting.get("threshold")
        if category not in VertexAIHarmCategory.__members__:
            logger.error("Safety setting category %s is invalid", category)
        elif threshold not in VertexAIHarmBlockThreshold.__members__:
            logger.error("Safety threshold %s is invalid", threshold)
        else:
            vertexai_safety_settings.append(VertexAISafetySetting(category=category, threshold=threshold))
    return vertexai_safety_settings
439+
375440

376441
def _to_pil(data: str) -> Image.Image:
377442
"""

autogen/oai/openai_utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313
from openai.types.beta.assistant import Assistant
1414
from packaging.version import parse
1515

16-
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
16+
NON_CACHE_KEY = [
17+
"api_key",
18+
"base_url",
19+
"api_type",
20+
"api_version",
21+
"azure_ad_token",
22+
"azure_ad_token_provider",
23+
"credentials",
24+
]
1725
DEFAULT_AZURE_API_VERSION = "2024-02-01"
1826
OAI_PRICE1K = {
1927
# https://openai.com/api/pricing/

test/agentchat/test_chats.py

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import autogen
1212
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
13+
from autogen.agentchat.chat import _post_process_carryover_item
1314

1415
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
1516
from conftest import reason, skip_openai # noqa: E402
@@ -620,6 +621,15 @@ def my_writing_task(sender, recipient, context):
620621
print(chat_results[1].summary, chat_results[1].cost)
621622

622623

624+
def test_post_process_carryover_item():
    """Check carryover post-processing for a plain string and a Gemini-style dict."""
    text = "How can I help you?"
    gemini_message = {"content": "How can I help you?", "role": "model"}

    # A message dict is reduced to its content.
    from_dict = _post_process_carryover_item(gemini_message)
    assert from_dict == gemini_message["content"], "Incorrect carryover postprocessing"

    # A plain string passes through unchanged.
    from_text = _post_process_carryover_item(text)
    assert from_text == text, "Incorrect carryover postprocessing"
631+
632+
623633
if __name__ == "__main__":
624634
test_chats()
625635
# test_chats_general()
@@ -628,3 +638,4 @@ def my_writing_task(sender, recipient, context):
628638
# test_chats_w_func()
629639
# test_chat_messages_for_summary()
630640
# test_udf_message_in_chats()
641+
test_post_process_carryover_item()

test/agentchat/test_conversable_agent.py

+56
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,58 @@ def sample_function():
14631463
)
14641464

14651465

1466+
def test_process_gemini_carryover():
    """Check `_process_carryover` with a carryover list holding a Gemini-style dict."""
    agent = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
    message = "I am your assistant."
    context_text = "How can I help you?"

    processed = agent._process_carryover(
        content=message,
        kwargs={"carryover": [{"content": context_text}]},
    )

    expected = message + "\nContext: \n" + context_text
    assert processed == expected, "Incorrect carryover processing"
1473+
1474+
1475+
def test_process_carryover():
    """Check `_process_carryover` with string, list and ``None`` carryover values."""
    agent = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
    message = "I am your assistant."
    context_text = "How can I help you?"
    expected = message + "\nContext: \n" + context_text

    # A bare string and a one-element list are expected to give the same text.
    for carryover_value in (context_text, ["How can I help you?"]):
        processed = agent._process_carryover(content=message, kwargs={"carryover": carryover_value})
        assert processed == expected, "Incorrect carryover processing"

    # With ``None`` as carryover the message is expected to come back as-is.
    unchanged = agent._process_carryover(content=message, kwargs={"carryover": None})
    assert unchanged == message, "Incorrect carryover processing"
1490+
1491+
1492+
def test_handle_gemini_carryover():
    """Check `_handle_carryover` with a carryover list holding a Gemini-style dict."""
    agent = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
    message = "I am your assistant"
    context_text = "How can I help you?"

    handled = agent._handle_carryover(
        message=message,
        kwargs={"carryover": [{"content": context_text}]},
    )

    expected = message + "\nContext: \n" + context_text
    assert handled == expected, "Incorrect carryover processing"
1499+
1500+
1501+
def test_handle_carryover():
    """Check `_handle_carryover` with string, list and ``None`` carryover values."""
    agent = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
    message = "I am your assistant."
    context_text = "How can I help you?"
    expected = message + "\nContext: \n" + context_text

    # A bare string and a one-element list are expected to give the same text.
    for carryover_value in (context_text, ["How can I help you?"]):
        handled = agent._handle_carryover(message=message, kwargs={"carryover": carryover_value})
        assert handled == expected, "Incorrect carryover processing"

    # With ``None`` as carryover the message is expected to come back as-is.
    unchanged = agent._handle_carryover(message=message, kwargs={"carryover": None})
    assert unchanged == message, "Incorrect carryover processing"
1516+
1517+
14661518
if __name__ == "__main__":
14671519
# test_trigger()
14681520
# test_context()
@@ -1473,6 +1525,10 @@ def sample_function():
14731525
# test_max_turn()
14741526
# test_process_before_send()
14751527
# test_message_func()
1528+
14761529
test_summary()
14771530
test_adding_duplicate_function_warning()
14781531
# test_function_registration_e2e_sync()
1532+
1533+
test_process_gemini_carryover()
1534+
test_process_carryover()

0 commit comments

Comments
 (0)