From 143eb7d547351b28520e646826c22e335d950f22 Mon Sep 17 00:00:00 2001 From: ElishaKay Date: Fri, 1 Nov 2024 03:15:15 +0000 Subject: [PATCH] pulling nextjs chat settings from config --- backend/chat/chat.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/backend/chat/chat.py b/backend/chat/chat.py index a1be8e936..49c0cf739 100644 --- a/backend/chat/chat.py +++ b/backend/chat/chat.py @@ -28,20 +28,36 @@ def __init__( def create_agent(self): """Create React Agent Graph""" - #If not vector store, split and talk to the report - llm_provider_name = getattr(self.config, "llm_provider") - fast_llm_model = getattr(self.config, "fast_llm_model") - temperature = getattr(self.config, "temperature") - fast_token_limit = getattr(self.config, "fast_token_limit") + cfg = Config() - provider = get_llm(llm_provider_name, model=fast_llm_model, temperature=temperature, max_tokens=fast_token_limit, **self.config.llm_kwargs).llm + # Retrieve LLM using get_llm with settings from config + provider = get_llm( + llm_provider=cfg.smart_llm_provider, + model=cfg.smart_llm_model, + temperature=0.35, + max_tokens=cfg.smart_token_limit, + **self.config.llm_kwargs + ).llm + + # If vector_store is not initialized, process documents and add to vector_store if not self.vector_store: documents = self._process_document(self.report) self.chat_config = {"configurable": {"thread_id": str(uuid.uuid4())}} - self.embedding = Memory(getattr(self.config, 'embedding_provider', None), self.headers).get_embeddings() + self.embedding = Memory( + cfg.embedding_provider, + cfg.embedding_model, + **cfg.embedding_kwargs + ).get_embeddings() self.vector_store = InMemoryVectorStore(self.embedding) self.vector_store.add_texts(documents) - graph = create_react_agent(provider, tools=[self.vector_store_tool(self.vector_store)], checkpointer=MemorySaver()) + + # Create the React Agent Graph with the configured provider + graph = create_react_agent( + provider, + tools=[self.vector_store_tool(self.vector_store)], + checkpointer=MemorySaver() + ) + return graph def vector_store_tool(self, vector_store) -> Tool: