1616
1717"""Docs Agent"""
1818
19- import os
20- import sys
19+ import typing
2120
2221from absl import logging
2322import google .api_core
2423import google .ai .generativelanguage as glm
2524from chromadb .utils import embedding_functions
2625
27- from docs_agent .storage .chroma import Chroma , Format , ChromaEnhanced
28- from docs_agent .models .palm import PaLM
26+ from docs_agent .storage .chroma import ChromaEnhanced
2927
3028from docs_agent .models .google_genai import Gemini
3129
32- from docs_agent .utilities .config import ProductConfig , ReadConfig , Input , Models
33- from docs_agent .models import tokenCount
30+ from docs_agent .utilities .config import ProductConfig , Models
3431from docs_agent .preprocess .splitters import markdown_splitter
3532
3633from docs_agent .preprocess .splitters .markdown_splitter import Section as Section
37- from docs_agent .utilities .helpers import get_project_path
38- from docs_agent .postprocess .docs_retriever import FullPage as FullPage
3934from docs_agent .postprocess .docs_retriever import SectionDistance as SectionDistance
4035from docs_agent .postprocess .docs_retriever import (
4136 SectionProbability as SectionProbability ,
@@ -49,6 +44,8 @@ class DocsAgent:
4944 def __init__ (self , config : ProductConfig , init_chroma : bool = True ):
5045 # Models settings
5146 self .config = config
47+ self .embedding_model = str (self .config .models .embedding_model )
48+ self .api_endpoint = str (self .config .models .api_endpoint )
5249 # Use the new chroma db for all queries
5350 # Should make a function for this or clean this behavior
5451 if init_chroma :
@@ -62,9 +59,9 @@ def __init__(self, config: ProductConfig, init_chroma: bool = True):
6259 )
6360 self .collection = self .chroma .get_collection (
6461 self .collection_name ,
65- embedding_model = self .config . models . embedding_model ,
62+ embedding_model = self .embedding_model ,
6663 embedding_function = embedding_function_gemini_retrieval (
67- self .config .models .api_key
64+ self .config .models .api_key , self . embedding_model
6865 ),
6966 )
7067 # AQA model settings
@@ -77,9 +74,12 @@ def __init__(self, config: ProductConfig, init_chroma: bool = True):
7774 self .context_model = "models/gemini-pro"
7875 gemini_model_config = Models (
7976 language_model = self .context_model ,
80- embedding_model = "models/embedding-001" ,
77+ embedding_model = self .embedding_model ,
78+ api_endpoint = self .api_endpoint ,
79+ )
80+ self .gemini = Gemini (
81+ models_config = gemini_model_config , conditions = config .conditions
8182 )
82- self .gemini = Gemini (models_config = gemini_model_config )
8383 # Semantic retriever
8484 if self .config .db_type == "google_semantic_retriever" :
8585 for item in self .config .db_configs :
@@ -93,9 +93,34 @@ def __init__(self, config: ProductConfig, init_chroma: bool = True):
9393 )
9494 self .aqa_response_buffer = ""
9595
96- if self .config .models .language_model == "models/gemini-pro" :
97- self .gemini = Gemini (models_config = config .models )
98- self .context_model = "models/gemini-pro"
96+ if self .config .models .language_model .startswith ("models/gemini" ):
97+ self .gemini = Gemini (
98+ models_config = config .models , conditions = config .conditions
99+ )
100+ self .context_model = self .config .models .language_model
101+
102+ # Always initialize the gemini-pro model for other tasks.
103+ gemini_pro_model_config = Models (
104+ language_model = "models/gemini-pro" ,
105+ embedding_model = self .embedding_model ,
106+ api_endpoint = self .api_endpoint ,
107+ )
108+ self .gemini_pro = Gemini (
109+ models_config = gemini_pro_model_config , conditions = config .conditions
110+ )
111+
112+ if self .config .app_mode == "1.5" :
113+ # Initialize the gemini-1.5.pro model for summarization.
114+ gemini_15_model_config = Models (
115+ language_model = "models/gemini-1.5-pro-latest" ,
116+ embedding_model = self .embedding_model ,
117+ api_endpoint = self .api_endpoint ,
118+ )
119+ self .gemini_15 = Gemini (
120+ models_config = gemini_15_model_config , conditions = config .conditions
121+ )
122+ else :
123+ self .gemini_15 = self .gemini_pro
99124
100125 # Use this method for talking to a Gemini content model
101126 def ask_content_model_with_context (self , context , question ):
@@ -261,7 +286,7 @@ def ask_aqa_model_using_corpora(self, question, answer_style: str = "VERBOSE"):
261286
262287 def ask_aqa_model (self , question ):
263288 response = ""
264- if self .db_type == "ONLINE_STORAGE " :
289+ if self .config . db_type == "google_semantic_retriever " :
265290 response = self .ask_aqa_model_using_corpora (question )
266291 else :
267292 response = self .ask_aqa_model_using_local_vector_store (question )
@@ -436,7 +461,11 @@ def query_vector_store_to_build(
436461 # If prompt is "fact_checker" it will use the fact_check_question from
437462 # config.yaml for the prompt
438463 def ask_content_model_with_context_prompt (
439- self , context : str , question : str , prompt : str = None
464+ self ,
465+ context : str ,
466+ question : str ,
467+ prompt : typing .Optional [str ] = None ,
468+ model : typing .Optional [str ] = None ,
440469 ):
441470 if prompt == None :
442471 prompt = self .config .conditions .condition_text
@@ -447,7 +476,13 @@ def ask_content_model_with_context_prompt(
447476 if self .config .log_level == "VERBOSE" :
448477 self .print_the_prompt (new_prompt )
449478 try :
450- response = self .gemini .generate_content (contents = new_prompt )
479+ response = ""
480+ if model == "gemini-pro" :
481+ response = self .gemini_pro .generate_content (contents = new_prompt )
482+ elif model == "gemini-1.5-pro" :
483+ response = self .gemini_15 .generate_content (contents = new_prompt )
484+ else :
485+ response = self .gemini .generate_content (contents = new_prompt )
451486 except :
452487 return self .config .conditions .model_error_message , new_prompt
453488 for chunk in response :
@@ -475,7 +510,7 @@ def ask_content_model_to_use_file(self, prompt: str, file: str):
475510 # Use this method for asking a Gemini content model for fact-checking.
476511 # This uses ask_content_model_with_context_prompt w
477512 def ask_content_model_to_fact_check_prompt (self , context : str , prev_response : str ):
478- question = self .fact_check_question + "\n \n Text: "
513+ question = self .config . conditions . fact_check_question + "\n \n Text: "
479514 question += prev_response
480515 return self .ask_content_model_with_context_prompt (
481516 context = context , question = question , prompt = ""
@@ -487,7 +522,7 @@ def generate_embedding(self, text, task_type: str = "SEMANTIC_SIMILARITY"):
487522
488523
# Builds the Gemini embedding function used by the Chroma collection,
# bound to a caller-supplied API key and embedding model name.
def embedding_function_gemini_retrieval(api_key, embedding_model: str):
    """Return a Google Generative AI embedding function for query retrieval.

    Args:
        api_key: API key used to authenticate against the Gemini API.
        embedding_model: Fully-qualified embedding model name
            (e.g. "models/embedding-001").

    Returns:
        A chromadb ``GoogleGenerativeAiEmbeddingFunction`` configured with
        ``task_type="RETRIEVAL_QUERY"``.
    """
    return embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        task_type="RETRIEVAL_QUERY",
        model_name=embedding_model,
        api_key=api_key,
    )
0 commit comments