
Commit b016a84

feat: add Google embedding support & update setup (#550) bump:patch
1 parent: 159f4da

5 files changed: +109 / -15 lines


flowsettings.py

Lines changed: 13 additions & 5 deletions
@@ -26,6 +26,7 @@
 
 KH_ENABLE_FIRST_SETUP = True
 KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")
 
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
@@ -162,7 +163,7 @@
 KH_LLMS["ollama"] = {
     "spec": {
         "__type__": "kotaemon.llms.ChatOpenAI",
-        "base_url": "http://localhost:11434/v1/",
+        "base_url": KH_OLLAMA_URL,
         "model": config("LOCAL_MODEL", default="llama3.1:8b"),
         "api_key": "ollama",
     },
@@ -171,7 +172,7 @@
 KH_EMBEDDINGS["ollama"] = {
     "spec": {
         "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-        "base_url": "http://localhost:11434/v1/",
+        "base_url": KH_OLLAMA_URL,
         "model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
         "api_key": "ollama",
     },
@@ -195,11 +196,11 @@
     },
     "default": False,
 }
-KH_LLMS["gemini"] = {
+KH_LLMS["google"] = {
     "spec": {
         "__type__": "kotaemon.llms.chats.LCGeminiChat",
-        "model_name": "gemini-1.5-pro",
-        "api_key": "your-key",
+        "model_name": "gemini-1.5-flash",
+        "api_key": config("GOOGLE_API_KEY", default="your-key"),
     },
     "default": False,
 }
@@ -231,6 +232,13 @@
     },
     "default": False,
 }
+KH_EMBEDDINGS["google"] = {
+    "spec": {
+        "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
+        "model": "models/text-embedding-004",
+        "google_api_key": config("GOOGLE_API_KEY", default="your-key"),
+    }
+}
 # KH_EMBEDDINGS["huggingface"] = {
 #     "spec": {
 #         "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",

libs/kotaemon/kotaemon/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .langchain_based import (
     LCAzureOpenAIEmbeddings,
     LCCohereEmbeddings,
+    LCGoogleEmbeddings,
     LCHuggingFaceEmbeddings,
     LCOpenAIEmbeddings,
 )
@@ -18,6 +19,7 @@
     "LCAzureOpenAIEmbeddings",
     "LCCohereEmbeddings",
     "LCHuggingFaceEmbeddings",
+    "LCGoogleEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "FastEmbedEmbeddings",

libs/kotaemon/kotaemon/embeddings/langchain_based.py

Lines changed: 35 additions & 0 deletions
@@ -219,3 +219,38 @@ def _get_lc_class(self):
         from langchain.embeddings import HuggingFaceBgeEmbeddings
 
         return HuggingFaceBgeEmbeddings
+
+
+class LCGoogleEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's Google GenAI embedding, focusing on key parameters"""
+
+    google_api_key: str = Param(
+        help="API key (https://aistudio.google.com/app/apikey)",
+        default=None,
+        required=True,
+    )
+    model: str = Param(
+        help="Model name to use (https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)",  # noqa
+        default="models/text-embedding-004",
+        required=True,
+    )
+
+    def __init__(
+        self,
+        model: str = "models/text-embedding-004",
+        google_api_key: Optional[str] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            google_api_key=google_api_key,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        except ImportError:
+            raise ImportError("Please install langchain-google-genai")
+
+        return GoogleGenerativeAIEmbeddings
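
A hedged usage sketch of the new wrapper. It assumes `langchain-google-genai` is installed and `GOOGLE_API_KEY` is set; the import path `kotaemon.embeddings.LCGoogleEmbeddings` comes from the `__init__.py` change above, while the callable-on-text convention is kotaemon's generic `BaseEmbeddings` interface rather than anything shown in this diff.

```python
import os

from kotaemon.embeddings import LCGoogleEmbeddings

# Instantiate with the same defaults the class declares above.
embedder = LCGoogleEmbeddings(
    model="models/text-embedding-004",
    google_api_key=os.environ.get("GOOGLE_API_KEY", "your-key"),
)

# _get_lc_class() resolves to langchain_google_genai.GoogleGenerativeAIEmbeddings,
# which performs the actual embedding call.
docs = embedder(["What is retrieval-augmented generation?"])
print(docs)
```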

libs/ktem/ktem/embeddings/manager.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,7 @@ def load_vendors(self):
             AzureOpenAIEmbeddings,
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
+            LCGoogleEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
             TeiEndpointEmbeddings,
@@ -68,6 +69,7 @@ def load_vendors(self):
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            LCGoogleEmbeddings,
             TeiEndpointEmbeddings,
         ]
 

libs/ktem/ktem/pages/setup.py

Lines changed: 57 additions & 10 deletions
@@ -9,7 +9,10 @@
 from theflow.settings import settings as flowsettings
 
 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
-DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
+KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
+DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
+if DEFAULT_OLLAMA_URL.endswith("/"):
+    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]
 
 
 DEMO_MESSAGE = (
@@ -55,8 +58,9 @@ def on_building_ui(self):
         gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
         self.radio_model = gr.Radio(
             [
-                ("Cohere API (*free registration* available) - recommended", "cohere"),
-                ("OpenAI API (for more advance models)", "openai"),
+                ("Cohere API (*free registration*) - recommended", "cohere"),
+                ("Google API (*free registration*)", "google"),
+                ("OpenAI API (for GPT-based models)", "openai"),
                 ("Local LLM (for completely *private RAG*)", "ollama"),
             ],
             label="Select your model provider",
@@ -92,6 +96,18 @@ def on_building_ui(self):
                     show_label=False, placeholder="Cohere API Key"
                 )
 
+            with gr.Column(visible=False) as self.google_option:
+                gr.Markdown(
+                    (
+                        "#### Google API Key\n\n"
+                        "(register your free API key "
+                        "at https://aistudio.google.com/app/apikey)"
+                    )
+                )
+                self.google_api_key = gr.Textbox(
+                    show_label=False, placeholder="Google API Key"
+                )
+
             with gr.Column(visible=False) as self.ollama_option:
                 gr.Markdown(
                     (
@@ -119,7 +135,12 @@ def on_register_events(self):
                 self.openai_api_key.submit,
             ],
             fn=self.update_model,
-            inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
+            inputs=[
+                self.cohere_api_key,
+                self.openai_api_key,
+                self.google_api_key,
+                self.radio_model,
+            ],
             outputs=[self.setup_log],
             show_progress="hidden",
         )
@@ -147,13 +168,19 @@ def on_register_events(self):
             fn=self.switch_options_view,
             inputs=[self.radio_model],
             show_progress="hidden",
-            outputs=[self.cohere_option, self.openai_option, self.ollama_option],
+            outputs=[
+                self.cohere_option,
+                self.openai_option,
+                self.ollama_option,
+                self.google_option,
+            ],
         )
 
     def update_model(
         self,
         cohere_api_key,
         openai_api_key,
+        google_api_key,
        radio_model_value,
     ):
         # skip if KH_DEMO_MODE
@@ -221,12 +248,32 @@ def update_model(
                     },
                     default=True,
                 )
+        elif radio_model_value == "google":
+            if google_api_key:
+                llms.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.llms.chats.LCGeminiChat",
+                        "model_name": "gemini-1.5-flash",
+                        "api_key": google_api_key,
+                    },
+                    default=True,
+                )
+                embeddings.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
+                        "model": "models/text-embedding-004",
+                        "google_api_key": google_api_key,
+                    },
+                    default=True,
+                )
         elif radio_model_value == "ollama":
             llms.update(
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.llms.ChatOpenAI",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "llama3.1:8b",
                     "api_key": "ollama",
                 },
@@ -236,7 +283,7 @@ def update_model(
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "nomic-embed-text",
                     "api_key": "ollama",
                 },
@@ -270,7 +317,7 @@ def update_model(
                 yield log_content
         except Exception as e:
             log_content += (
-                "Make sure you have download and installed Ollama correctly."
+                "Make sure you have download and installed Ollama correctly. "
                 f"Got error: {str(e)}"
             )
             yield log_content
@@ -345,9 +392,9 @@ def update_default_settings(self, radio_model_value, default_settings):
         return default_settings
 
     def switch_options_view(self, radio_model_value):
-        components_visible = [gr.update(visible=False) for _ in range(3)]
+        components_visible = [gr.update(visible=False) for _ in range(4)]
 
-        values = ["cohere", "openai", "ollama", None]
+        values = ["cohere", "openai", "ollama", "google", None]
         assert radio_model_value in values, f"Invalid value {radio_model_value}"
 
         if radio_model_value is not None:
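
As a side note, the DEFAULT_OLLAMA_URL derivation added at the top of setup.py is easy to verify in isolation. The sketch below uses the default value of KH_OLLAMA_URL; the variable names and logic mirror the diff, and the only assumption is that the derived /api URL is what the setup flow uses to reach Ollama's native API (that usage is not shown in this diff).

```python
# KH_OLLAMA_URL points at Ollama's OpenAI-compatible /v1/ endpoint; the setup page
# derives the native /api base URL from it and trims the trailing slash.
KH_OLLAMA_URL = "http://localhost:11434/v1/"

DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
if DEFAULT_OLLAMA_URL.endswith("/"):
    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]

print(DEFAULT_OLLAMA_URL)  # http://localhost:11434/api
```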
