From 7791a124c5c6160b09acb8421486941e523a286f Mon Sep 17 00:00:00 2001
From: Jacopo Chevallard
Date: Mon, 23 Sep 2024 17:30:02 +0200
Subject: [PATCH] fix: correctly return the model metadata used in
 chat-with-llm mode

---
 .../modules/rag_service/rag_service.py        | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/backend/api/quivr_api/modules/rag_service/rag_service.py b/backend/api/quivr_api/modules/rag_service/rag_service.py
index ca0240e16140..c1f3ee7da6a3 100644
--- a/backend/api/quivr_api/modules/rag_service/rag_service.py
+++ b/backend/api/quivr_api/modules/rag_service/rag_service.py
@@ -2,11 +2,12 @@
 import os
 from uuid import UUID, uuid4
 
+from quivr_api.utils.uuid_generator import generate_uuid_from_string
 from quivr_core.brain import Brain as BrainCore
 from quivr_core.chat import ChatHistory as ChatHistoryCore
 from quivr_core.config import LLMEndpointConfig, RetrievalConfig
 from quivr_core.llm.llm_endpoint import LLMEndpoint
-from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata
+from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata
 from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 from quivr_api.logger import get_logger
@@ -262,6 +263,9 @@ async def generate_answer_stream(
 
         llm = self.get_llm(retrieval_config)
 
+        # Get model metadata
+        model_metadata = await self.model_service.get_model(self.brain.name)
+
         brain_core = BrainCore(
             name=self.brain.name,
             id=self.brain.id,
@@ -281,10 +285,22 @@ async def generate_answer_stream(
             "user_message": question,  # TODO: define result
             "message_time": datetime.datetime.now(),  # TODO: define result
             "prompt_title": (self.prompt.title if self.prompt else ""),
-            "brain_name": self.brain.name if self.brain else None,
-            "brain_id": self.brain.brain_id if self.brain else None,
+            # brain_name and brain_id must be None in the chat-with-llm case, as this forces the frontend to look for the model metadata
+            "brain_name": self.brain.name if self.brain_service else None,
+            "brain_id": self.brain.brain_id if self.brain_service else None,
         }
 
+        metadata_model = {}
+        if model_metadata:
+            metadata_model = ChatLLMMetadata(
+                name=self.brain.name,
+                description=model_metadata.description,
+                image_url=model_metadata.image_url,
+                display_name=model_metadata.display_name,
+                brain_id=str(generate_uuid_from_string(self.brain.name)),
+                brain_name=self.model_to_use,
+            )
+
         async for response in brain_core.ask_streaming(
             question=question,
             retrieval_config=retrieval_config,
@@ -306,6 +322,10 @@ async def generate_answer_stream(
                     streamed_chat_history.metadata["snippet_emoji"] = (
                         self.brain.snippet_emoji if self.brain else None
                     )
+                    if metadata_model:
+                        streamed_chat_history.metadata["metadata_model"] = (
+                            metadata_model
+                        )
                 full_answer += response.answer
             yield f"data: {streamed_chat_history.model_dump_json()}"
 
@@ -315,6 +335,7 @@ async def generate_answer_stream(
             metadata=response.metadata.model_dump(),
             **message_metadata,
         )
+
         if streamed_chat_history.metadata:
             streamed_chat_history.metadata["snippet_color"] = (
                 self.brain.snippet_color if self.brain else None
@@ -322,6 +343,8 @@ async def generate_answer_stream(
             streamed_chat_history.metadata["snippet_emoji"] = (
                 self.brain.snippet_emoji if self.brain else None
             )
+            if metadata_model:
+                streamed_chat_history.metadata["metadata_model"] = metadata_model
 
         sources_urls = (
             await generate_source(
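
Note: the patch imports generate_uuid_from_string but its implementation is
not part of this diff. Below is a minimal sketch of such a deterministic
string-to-UUID helper, assuming it follows the standard uuid5 scheme; the
actual quivr_api.utils.uuid_generator implementation may differ.

    import uuid

    def generate_uuid_from_string(value: str) -> uuid.UUID:
        # uuid5 is deterministic: the same namespace and name always map to
        # the same UUID, so each model name yields a stable pseudo brain_id
        # that the frontend can rely on across requests.
        return uuid.uuid5(uuid.NAMESPACE_DNS, value)

    # Example (hypothetical model name): the same input always gives the same UUID.
    assert generate_uuid_from_string("gpt-4o") == generate_uuid_from_string("gpt-4o")

Determinism matters here because brain_id is regenerated on every request
from self.brain.name; a random uuid4 would break the frontend's ability to
associate the streamed metadata with a model between calls.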
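
Note: ChatLLMMetadata comes from quivr_core.models and its definition is not
shown in this diff. The following rough sketch of the shape being populated
is inferred purely from the call site above; field types and optionality are
assumptions, not the real quivr_core definition.

    from pydantic import BaseModel

    class ChatLLMMetadata(BaseModel):
        # Field names mirror the keyword arguments used in the patch;
        # all types below are assumed.
        name: str
        description: str | None = None
        image_url: str | None = None
        display_name: str | None = None
        brain_id: str | None = None
        brain_name: str | None = None

Because this is a Pydantic model, storing it under
streamed_chat_history.metadata["metadata_model"] keeps it serializable when
the history is emitted via model_dump_json() in the streaming loop.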