Skip to content

Commit

Permalink
feat: 加入Gemini Pro (Vision) 支持 #1039
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Feb 20, 2024
1 parent c904b2a commit 1318660
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 9 deletions.
2 changes: 1 addition & 1 deletion config_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

//== API 配置 ==
"openai_api_key": "", // 你的 OpenAI API Key,一般必填,若空缺则需在图形界面中填入API Key
"google_palm_api_key": "", // 你的 Google PaLM API Key,用于 Google PaLM 对话模型
"google_genai_api_key": "", // 你的 Google PaLM API Key,用于 Google PaLM 对话模型
"xmchat_api_key": "", // 你的 xmchat API Key,用于 XMChat 对话模型
"minimax_api_key": "", // 你的 MiniMax API Key,用于 MiniMax 对话模型
"minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
Expand Down
5 changes: 3 additions & 2 deletions locale/en_US.json
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
"本地编制索引": "Local indexing",
"是否在本地编制知识库索引?如果是,可以在使用本地模型时离线使用知识库,否则使用OpenAI服务来编制索引(需要OpenAI API Key)。请确保你的电脑有至少16GB内存。本地索引模型需要从互联网下载。": "Do you want to index the knowledge base locally? If so, you can use the knowledge base offline when using the local model, otherwise use the OpenAI service to index (requires OpenAI API Key). Make sure your computer has at least 16GB of memory. The local index model needs to be downloaded from the Internet.",
"现在开始设置其他在线模型的API Key": "Start setting the API Key for other online models",
"是否设置默认 Google Palm API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。": "Set the default Google Palm API Key? If set, the API Key will be automatically loaded when the software starts, and there is no need to manually enter it in the UI. If not set, you can manually enter the API Key after the software starts.",
"是否设置默认 Google AI Studio API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。": "Set the default Google Palm API Key? If set, the API Key will be automatically loaded when the software starts, and there is no need to manually enter it in the UI. If not set, you can manually enter the API Key after the software starts.",
"是否设置默认 XMChat API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。": "Set the default XMChat API Key? If set, the API Key will be automatically loaded when the software starts, and there is no need to manually enter it in the UI. If not set, you can manually enter the API Key after the software starts.",
"是否设置默认 MiniMax API 密钥和 Group ID?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 MiniMax 模型。": "Set the default MiniMax API Key and Group ID? If set, the API Key will be automatically loaded when the software starts, and there is no need to manually enter it in the UI. If not set, the MiniMax model will not be available.",
"你的": "Your ",
Expand Down Expand Up @@ -227,5 +227,6 @@
"设置完成。现在请重启本程序。": "Setup completed. Please restart this program now.",
"你设置了 ": "You set ",
" 为: ": " as: ",
"输入的不是数字,将使用默认值。": "The input is not a number, the default value will be used."
"输入的不是数字,将使用默认值。": "The input is not a number, the default value will be used.",
"由于下面的原因,Google 拒绝返回 Gemini 的回答:\n\n": "For the following reasons, Google refuses to return Gemini's response:\n\n",
}
11 changes: 7 additions & 4 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ def load_config_to_environ(key_list):

HIDE_MY_KEY = config.get("hide_my_key", False)

google_palm_api_key = config.get("google_palm_api_key", "")
google_palm_api_key = os.environ.get(
"GOOGLE_PALM_API_KEY", google_palm_api_key)
os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key
google_genai_api_key = os.environ.get(
"GOOGLE_PALM_API_KEY", None)
google_genai_api_key = os.environ.get(
"GOOGLE_GENAI_API_KEY", None)
google_genai_api_key = config.get("google_palm_api_key", google_genai_api_key)
google_genai_api_key = config.get("google_genai_api_key", google_genai_api_key)
os.environ["GOOGLE_GENAI_API_KEY"] = google_genai_api_key

xmchat_api_key = config.get("xmchat_api_key", "")
os.environ["XMCHAT_API_KEY"] = xmchat_api_key
Expand Down
81 changes: 81 additions & 0 deletions modules/models/GoogleGemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import logging
import textwrap
import uuid

import google.generativeai as genai
import gradio as gr
import PIL
import requests

from modules.presets import i18n

from ..index_func import construct_index
from ..utils import count_token
from .base_model import BaseLLMModel


class GoogleGeminiClient(BaseLLMModel):
def __init__(self, model_name, api_key, user_name="") -> None:
super().__init__(model_name=model_name, user=user_name)
self.api_key = api_key
if "vision" in model_name.lower():
self.multimodal = True
else:
self.multimodal = False
self.image_paths = []

def _get_gemini_style_input(self):
self.history.extend([{"role": "image", "content": i} for i in self.image_paths])
self.image_paths = []
messages = []
for item in self.history:
if item["role"] == "image":
messages.append(PIL.Image.open(item["content"]))
else:
messages.append(item["content"])
return messages

def to_markdown(self, text):
text = text.replace("•", " *")
return textwrap.indent(text, "> ", predicate=lambda _: True)

def handle_file_upload(self, files, chatbot, language):
if files:
if self.multimodal:
for file in files:
if file.name:
self.image_paths.append(file.name)
chatbot = chatbot + [((file.name,), None)]
return None, chatbot, None
else:
construct_index(self.api_key, file_src=files)
status = i18n("索引构建完成")
return gr.Files.update(), chatbot, status

def get_answer_at_once(self):
genai.configure(api_key=self.api_key)
messages = self._get_gemini_style_input()
model = genai.GenerativeModel(self.model_name)
response = model.generate_content(messages)
try:
return self.to_markdown(response.text), len(response.text)
except ValueError:
return (
i18n("由于下面的原因,Google 拒绝返回 Gemini 的回答:\n\n")
+ str(response.prompt_feedback),
0,
)

def get_answer_stream_iter(self):
genai.configure(api_key=self.api_key)
messages = self._get_gemini_style_input()
model = genai.GenerativeModel(self.model_name)
response = model.generate_content(messages, stream=True)
partial_text = ""
for i in response:
response = i.text
partial_text += response
yield partial_text
self.all_token_counts[-1] = count_token(partial_text)
yield partial_text
3 changes: 3 additions & 0 deletions modules/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ModelType(Enum):
OpenAIVision = 16
ERNIE = 17
DALLE3 = 18
GoogleGemini = 19

@classmethod
def get_type(cls, model_name: str):
Expand Down Expand Up @@ -184,6 +185,8 @@ def get_type(cls, model_name: str):
model_type = ModelType.ChuanhuAgent
elif "palm" in model_name_lower:
model_type = ModelType.GooglePaLM
elif "gemini" in model_name_lower:
model_type = ModelType.GoogleGemini
elif "midjourney" in model_name_lower:
model_type = ModelType.Midjourney
elif "azure" in model_name_lower or "api" in model_name_lower:
Expand Down
7 changes: 6 additions & 1 deletion modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,14 @@ def get_model(
msg = i18n("启用的工具:") + ", ".join([i.name for i in model.tools])
elif model_type == ModelType.GooglePaLM:
from .GooglePaLM import Google_PaLM_Client
access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
access_key = os.environ.get("GOOGLE_GENAI_API_KEY", access_key)
model = Google_PaLM_Client(
model_name, access_key, user_name=user_name)
elif model_type == ModelType.GoogleGemini:
from .GoogleGemini import GoogleGeminiClient
access_key = os.environ.get("GOOGLE_GENAI_API_KEY", access_key)
model = GoogleGeminiClient(
model_name, access_key, user_name=user_name)
elif model_type == ModelType.LangchainChat:
from .Azure import Azure_OpenAI_Client
model = Azure_OpenAI_Client(model_name, user_name=user_name)
Expand Down
10 changes: 10 additions & 0 deletions modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
"川虎助理",
"川虎助理 Pro",
"DALL-E 3",
"Gemini Pro",
"Gemini Pro Vision",
"GooglePaLM",
"xmchat",
"Azure OpenAI",
Expand Down Expand Up @@ -169,6 +171,14 @@
"model_name": "ERNIE-Bot-4",
"token_limit": 1024,
},
"Gemini Pro": {
"model_name": "gemini-pro",
"token_limit": 30720,
},
"Gemini Pro Vision": {
"model_name": "gemini-pro-vision",
"token_limit": 30720,
}
}

if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
Expand Down
2 changes: 1 addition & 1 deletion modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,7 @@ def setup_wizard():
type=ConfigType.Password,
)
],
"是否设置默认 Google Palm API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。",
"是否设置默认 Google AI Studio API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。",
)
# XMChat
wizard.set(
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ faiss-cpu==1.7.4
duckduckgo-search>=4.1.1
arxiv
wikipedia
google-cloud-aiplatform
google.generativeai
unstructured
google-api-python-client
Expand Down

0 comments on commit 1318660

Please sign in to comment.