Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Myscale/sql self query #5

Merged
merged 4 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions langchain/chains/sql_database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class SQLDatabaseChain(Chain):
query_checker_prompt: Optional[BasePromptTemplate] = None
"""The prompt template that should be used by the query checker"""
sql_cmd_parser: Optional[BaseOutputParser] = None
"""Output parser that reformat the generated SQL"""
native_format: bool = False
"""If return_direct, controls whether to return in python native format instead of strings"""

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -127,7 +130,8 @@ def _call(
intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec
if self.sql_cmd_parser:
sql_cmd = self.sql_cmd_parser.parse(sql_cmd)
result = self.database.run(sql_cmd)
result = self.database.run(sql_cmd,
native_format=self.native_format if self.return_direct else False)
intermediate_steps.append(str(result)) # output: sql exec
else:
query_checker_prompt = self.query_checker_prompt or PromptTemplate(
Expand All @@ -154,7 +158,8 @@ def _call(
) # input: sql exec
if self.sql_cmd_parser:
checked_sql_command = self.sql_cmd_parser.parse(checked_sql_command)
result = self.database.run(checked_sql_command)
result = self.database.run(checked_sql_command,
native_format=self.native_format if self.return_direct else False)
intermediate_steps.append(str(result)) # output: sql exec
sql_cmd = checked_sql_command

Expand Down
2 changes: 2 additions & 0 deletions langchain/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from langchain.retrievers.wikipedia import WikipediaRetriever
from langchain.retrievers.zep import ZepRetriever
from langchain.retrievers.zilliz import ZillizRetriever
from langchain.retrievers.sql_database import SQLDatabaseChainRetriever

__all__ = [
"AmazonKendraRetriever",
Expand Down Expand Up @@ -58,4 +59,5 @@
"ZepRetriever",
"ZillizRetriever",
"DocArrayRetriever",
"SQLDatabaseChainRetriever",
]
32 changes: 32 additions & 0 deletions langchain/retrievers/sql_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""SQL Database Chain Retriever"""
from typing import Any, List, Coroutine
from pydantic import Field

from langchain.callbacks.manager import (
CallbackManagerForRetrieverRun,
)

from pydantic import BaseModel

from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.schema import BaseRetriever, Document

from langchain.schema.document import Document
from langchain.sql_database import SQLDatabase

class SQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that turns the rows produced by a SQLDatabaseChain into Documents.

    The wrapped chain is run with the user query and is expected to hand back
    its SQL result rows as a list of dict-like records. Each record becomes one
    ``Document``: the value stored under ``page_content_key`` is used as the
    page content and the whole record is kept as metadata.
    """

    sql_db_chain: SQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the SQL chain for ``query`` and wrap each returned row in a Document.

        Args:
            query: Natural-language query forwarded to the SQL chain.
            run_manager: Callback manager of the current retriever run; a child
                manager derived from it is handed to the chain.

        Returns:
            One ``Document`` per row returned by the chain.

        Raises:
            TypeError: If the chain is not configured with both
                ``native_format=True`` and ``return_direct=True``.
            KeyError: If a row has no ``page_content_key`` column.
        """
        if not self.sql_db_chain.native_format:
            raise TypeError("SQL Database Chain must return in native format. Try to turn `native_format` in this chain to `True`.")
        # The chain only forwards `native_format` to the database when
        # `return_direct` is enabled; without it the chain output is a string
        # and the row indexing below would fail with an unhelpful error.
        if not self.sql_db_chain.return_direct:
            raise TypeError("SQL Database Chain must return the SQL result directly. Try to turn `return_direct` in this chain to `True`.")
        # Hand the chain a child callback manager so its run is attached to
        # this retriever run instead of receiving the run manager itself.
        rows = self.sql_db_chain.run(query=query, callbacks=run_manager.get_child())
        return [
            Document(page_content=row[self.page_content_key], metadata=row)
            for row in rows
        ]

    async def _aget_relevant_documents(
        self, query: str, *args: Any, **kwargs: Any
    ) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
23 changes: 18 additions & 5 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
f' Columns: {str(index["column_names"])}'
)

def _try_eval(x):
try:
return eval(x)
except:
return x

def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
"""
Expand Down Expand Up @@ -336,11 +341,13 @@ def _get_sample_rows(self, table: Table) -> str:
f"{sample_rows_str}"
)

def run(self, command: str, fetch: str = "all") -> str:
def run(self, command: str, fetch: str = "all", native_format: bool = False) -> str:
"""Execute a SQL command and return a string representing the results.

If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.

If native_format is True, the rows are returned as Python objects
(a list of dicts for fetch="all", a single dict for fetch="one").
Otherwise:
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.

"""
with self._engine.begin() as connection:
Expand All @@ -361,7 +368,13 @@ def run(self, command: str, fetch: str = "all") -> str:
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")


# If native_format is set, return the rows as Python objects instead of a string
if native_format:
if isinstance(result, list):
return [{k: _try_eval(v) for k, v in dict(d).items()} for d in result]
return {k: _try_eval(v) for k, v in dict(result).items()}

# Convert columns values to string to avoid issues with sqlalchemy
# truncating text
if isinstance(result, list):
Expand Down