diff --git a/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py b/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py index a46bf7bd812..3e6d6f696b3 100644 --- a/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py +++ b/src/backend/base/langflow/components/retrievers/SelfQueryRetriever.py @@ -4,7 +4,7 @@ from langchain_core.vectorstores import VectorStore from langflow.custom import CustomComponent -from langflow.field_typing import BaseLanguageModel +from langflow.field_typing import BaseLanguageModel, Text from langflow.schema import Record from langflow.schema.message import Message @@ -14,25 +14,54 @@ class SelfQueryRetrieverComponent(CustomComponent): description: str = "Retriever that uses a vector store and an LLM to generate the vector store queries." icon = "LangChain" + def build_config(self): + return { + "query": { + "display_name": "Query", + "input_types": ["Message", "Text"], + "info": "Query to be passed as input.", + }, + "vectorstore": { + "display_name": "Vector Store", + "info": "Vector Store to be passed as input.", + }, + "attribute_infos": { + "display_name": "Metadata Field Info", + "info": "Metadata Field Info to be passed as input.", + }, + "document_content_description": { + "display_name": "Document Content Description", + "info": "Document Content Description to be passed as input.", + }, + "llm": { + "display_name": "LLM", + "info": "LLM to be passed as input.", + }, + } + def build( self, query: Message, vectorstore: VectorStore, - metadata_field_info: list[AttributeInfo], - document_content_description: str, + attribute_infos: list[Record], + document_content_description: Text, llm: BaseLanguageModel, ) -> Record: - metadata_field_info = [i[0] for i in metadata_field_info] - + metadata_field_infos = [AttributeInfo(**record.data) for record in attribute_infos] self_query_retriever = SelfQueryRetriever.from_llm( - llm, - vectorstore, - document_content_description, - metadata_field_info, + llm=llm, + vectorstore=vectorstore, + document_contents=document_content_description, + metadata_field_info=metadata_field_infos, enable_limit=True, ) - input_text = query.text + if isinstance(query, Message): + input_text = query.text + elif isinstance(query, str): + input_text = query + else: + raise ValueError(f"Query type {type(query)} not supported.") documents = self_query_retriever.invoke(input=input_text) records = [Record.from_document(document) for document in documents] self.status = records