diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index a18963271e3df..457eb63f57059 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -52,6 +52,9 @@ class SQLDatabaseChain(Chain): query_checker_prompt: Optional[BasePromptTemplate] = None """The prompt template that should be used by the query checker""" sql_cmd_parser: Optional[BaseOutputParser] = None + """Output parser that reformat the generated SQL""" + native_format: bool = False + """If return_direct, controls whether to return in python native format instead of strings""" class Config: """Configuration for this pydantic object.""" @@ -127,7 +130,8 @@ def _call( intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec if self.sql_cmd_parser: sql_cmd = self.sql_cmd_parser.parse(sql_cmd) - result = self.database.run(sql_cmd) + result = self.database.run(sql_cmd, + native_format=self.native_format if self.return_direct else False) intermediate_steps.append(str(result)) # output: sql exec else: query_checker_prompt = self.query_checker_prompt or PromptTemplate( @@ -154,7 +158,8 @@ def _call( ) # input: sql exec if self.sql_cmd_parser: checked_sql_command = self.sql_cmd_parser.parse(checked_sql_command) - result = self.database.run(checked_sql_command) + result = self.database.run(checked_sql_command, + native_format=self.native_format if self.return_direct else False) intermediate_steps.append(str(result)) # output: sql exec sql_cmd = checked_sql_command diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 137a8ae19ab26..2b1f5607fe9d8 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -29,6 +29,7 @@ from langchain.retrievers.wikipedia import WikipediaRetriever from langchain.retrievers.zep import ZepRetriever from langchain.retrievers.zilliz import ZillizRetriever +from langchain.retrievers.sql_database import SQLDatabaseChainRetriever __all__ = [ "AmazonKendraRetriever", @@ -58,4 +59,5 @@ "ZepRetriever", "ZillizRetriever", "DocArrayRetriever", + "SQLDatabaseChainRetriever", ] diff --git a/langchain/retrievers/sql_database.py b/langchain/retrievers/sql_database.py new file mode 100644 index 0000000000000..8bc22a7a383ac --- /dev/null +++ b/langchain/retrievers/sql_database.py @@ -0,0 +1,32 @@ +"""SQL Database Chain Retriever""" +from typing import Any, List, Coroutine +from pydantic import Field + +from langchain.callbacks.manager import ( + CallbackManagerForRetrieverRun, +) + +from pydantic import BaseModel + +from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.schema import BaseRetriever, Document + +from langchain.schema.document import Document +from langchain.sql_database import SQLDatabase + +class SQLDatabaseChainRetriever(BaseRetriever): + """Retriever that uses SQLDatabase as Retriever""" + + sql_db_chain: SQLDatabaseChain + """SQL Database Chain""" + page_content_key: str = "content" + """column name for page content of documents""" + + def _get_relevant_documents(self, query: str, run_manager: CallbackManagerForRetrieverRun) -> List[Document]: + if not self.sql_db_chain.native_format: + raise TypeError("SQL Database Chain must return in native format. Try to turn `native_format` in this chain to `True`.") + ret = self.sql_db_chain.run(query=query, callbacks=run_manager) + return [Document(page_content=r[self.page_content_key], metadata=r) for r in ret] + + async def _aget_relevant_documents(self, query: str, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, List[Document]]: + raise NotImplementedError \ No newline at end of file diff --git a/langchain/sql_database.py b/langchain/sql_database.py index d3b92e359419c..7a903c4cf1554 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -21,6 +21,11 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: f' Columns: {str(index["column_names"])}' ) +def _try_eval(x): + try: + return eval(x) + except: + return x def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: """ @@ -336,11 +341,13 @@ def _get_sample_rows(self, table: Table) -> str: f"{sample_rows_str}" ) - def run(self, command: str, fetch: str = "all") -> str: + def run(self, command: str, fetch: str = "all", native_format: bool = False) -> str: """Execute a SQL command and return a string representing the results. - - If the statement returns rows, a string of the results is returned. - If the statement returns no rows, an empty string is returned. + + If return_direct is set to true, then the result will be directly returned, + Otherwise: + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. """ with self._engine.begin() as connection: @@ -361,7 +368,13 @@ def run(self, command: str, fetch: str = "all") -> str: result = cursor.fetchone() # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") - + + # If return_direct then directly return the result + if native_format: + if isinstance(result, list): + return [{k: _try_eval(v) for k, v in dict(d).items()} for d in result] + return {k: _try_eval(v) for k, v in dict(result).items()} + # Convert columns values to string to avoid issues with sqlalchmey # trunacating text if isinstance(result, list):