SQLToolManager.db_query_tool() missing 1 required positional argument: 'query'/'self' #1855

Open
5 tasks done
jason571 opened this issue Sep 26, 2024 · 5 comments

Comments

@jason571

Checked other resources

  • I added a very descriptive title to this issue.
  • I searched the LangGraph/LangChain documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.
  • I am sure that this is a bug in LangGraph/LangChain rather than my code.
  • I am sure this is better as an issue rather than a GitHub discussion, since this is a LangGraph bug and not a design question.

Example Code

https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/sql-agent.ipynb

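from typing import Any

from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode

# Interface, SQLiteHandler and mylogging are the reporter's own modules (not shown in the issue).

class SQLToolManager: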
    def __init__(self):
        self.interface = Interface()
        self.db_handler = SQLiteHandler()
        self.llm = self.interface.get_current_model()
        self.toolkit = SQLDatabaseToolkit(db=self.db_handler.get_sql_database(), llm=self.llm)
        self.tools = self.toolkit.get_tools()
        self.list_tables_tool = next(tool for tool in self.tools if tool.name == "sql_db_list_tables")
        self.get_schema_tool = next(tool for tool in self.tools if tool.name == "sql_db_schema")

    def create_tool_node_with_fallback(self, tools: list) -> RunnableWithFallbacks[Any, dict]:
        """
        Create a ToolNode with a fallback to handle errors and surface them to the agent.
        """
        return ToolNode(tools).with_fallbacks(
            [RunnableLambda(self.handle_tool_error)], exception_key="error"
        )

    def handle_tool_error(self, state) -> dict:
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {
            "messages": [
                ToolMessage(
                    content=f"Error: {repr(error)}\n please fix your mistakes.",
                    tool_call_id=tc["id"],
                )
                for tc in tool_calls
            ]
        }

    @tool
    def db_query_tool(self, query) -> str:
        """
        Execute a SQL query against the database and get back the result.
        If the query is not correct, an error message will be returned.
        """
        mylogging.info(f"Executing query: {query}")
        result = self.db_handler.db.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result
Test code:
if __name__ == "__main__":
    manager = SQLToolManager()
    print(manager.list_tables())

    print(manager.get_schema_tool.invoke("Artist"))
    
    query = "SELECT * FROM Artist LIMIT 10;"
    #result = manager.db_query_tool.invoke({"query": query})
    result = manager.db_query_tool.invoke(query)
    print(result)
    
    result = manager.run_query(query)

Error Message and Stack Trace (if applicable)

result = manager.db_query_tool.invoke(query)
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 397, in invoke
    return self.run(tool_input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 586, in run
    raise error_to_raise
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 555, in run
    response = context.run(self._run, *tool_args, **tool_kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/structured.py", line 69, in _run
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: SQLToolManager.db_query_tool() missing 1 required positional argument: 'query'

result = manager.db_query_tool.invoke({"query": query})
  File "/mnt/c/workspace/pr_train/LLMs/src/sqlAgent/sqlTools.py", line 142, in <module>
    result = manager.db_query_tool.invoke({"query": query})
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 397, in invoke
    return self.run(tool_input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 586, in run
    raise error_to_raise
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 555, in run
    response = context.run(self._run, *tool_args, **tool_kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/structured.py", line 69, in _run
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: SQLToolManager.db_query_tool() missing 1 required positional argument: 'self'

Description

The example code above is adapted from the LangGraph SQL agent tutorial:
https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/sql-agent.ipynb

System Info

System Information

OS: Linux
OS Version: #3672-Microsoft Fri Jan 01 08:00:00 PST 2016
Python Version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Package Information

langchain_core: 0.2.40
langchain: 0.2.16
langchain_community: 0.2.17
langsmith: 0.1.120
langchain_cohere: 0.1.9
langchain_experimental: 0.0.65
langchain_google_community: 1.0.7
langchain_huggingface: 0.0.3
langchain_milvus: 0.1.4
langchain_openai: 0.1.22
langchain_text_splitters: 0.2.4
langgraph: 0.2.22
langserve: 0.2.2

Other Dependencies

aiohttp: 3.10.3
async-timeout: 4.0.3
beautifulsoup4: 4.12.3
cohere: 5.8.1
dataclasses-json: 0.6.7
db-dtypes: Installed. No version info available.
fastapi: 0.112.0
gapic-google-longrunning: Installed. No version info available.
google-api-core: 2.19.1
google-api-python-client: 2.141.0
google-auth-httplib2: 0.2.0
google-auth-oauthlib: Installed. No version info available.
google-cloud-aiplatform: 1.63.0
google-cloud-bigquery: 3.25.0
google-cloud-bigquery-storage: Installed. No version info available.
google-cloud-contentwarehouse: Installed. No version info available.
google-cloud-discoveryengine: Installed. No version info available.
google-cloud-documentai: Installed. No version info available.
google-cloud-documentai-toolbox: Installed. No version info available.
google-cloud-speech: Installed. No version info available.
google-cloud-storage: 2.18.2
google-cloud-texttospeech: Installed. No version info available.
google-cloud-translate: Installed. No version info available.
google-cloud-vision: 3.7.4
googlemaps: Installed. No version info available.
grpcio: 1.63.0
httpx: 0.27.2
huggingface-hub: 0.24.5
jsonpatch: 1.33
langgraph-checkpoint: 1.0.9
numpy: 1.26.4
openai: 1.40.6
orjson: 3.10.7
packaging: 24.1
pandas: 2.2.2
pyarrow: 17.0.0
pydantic: 2.8.2
pymilvus: 2.4.6
pyproject-toml: 0.0.10
PyYAML: 6.0.2
requests: 2.32.3
scipy: 1.14.0
sentence-transformers: 3.0.1
SQLAlchemy: 2.0.32
sse-starlette: Installed. No version info available.
tabulate: 0.9.0
tenacity: 8.3.0
tiktoken: 0.7.0
tokenizers: 0.19.1
transformers: 4.44.0

@eyurtsev
Contributor

This feels like it may be an issue with the @tool decorator being applied to a method. If so, it's a langchain-core issue.

But in the code snippet that you shared, why would this work:

    result = manager.run_query(query)

I don't see run_query defined anywhere.
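To make the suspected cause concrete, here is a minimal, library-free sketch. It assumes that @tool, like most decorators, runs once at class-definition time on the plain function and wraps it in an object that does not bind to instances. FakeTool and fake_tool are hypothetical stand-ins, not LangChain APIs. Because the wrapper is not a descriptor, manager.db_query_tool is the same wrapper object for every instance, so self stays in the argument list and is never filled in:

```python
import inspect
from typing import Any, Callable


class FakeTool:
    """Hypothetical stand-in for a function-based tool (not the LangChain class).

    It records the wrapped function's parameters when the decorator runs and
    defines no __get__, so attribute access through an instance never binds self.
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        self.args = list(inspect.signature(func).parameters)

    def invoke(self, tool_input: dict) -> Any:
        # Forward the input as keyword arguments, roughly as a structured tool does.
        return self.func(**tool_input)


def fake_tool(func: Callable[..., Any]) -> FakeTool:
    """Hypothetical decorator: wraps the plain function at class-definition time."""
    return FakeTool(func)


class SQLToolManager:
    @fake_tool
    def db_query_tool(self, query: str) -> str:
        return f"ran: {query}"


manager = SQLToolManager()
print(manager.db_query_tool.args)  # ['self', 'query']

try:
    manager.db_query_tool.invoke({"query": "SELECT 1"})
except TypeError as exc:
    # e.g. "SQLToolManager.db_query_tool() missing 1 required positional argument: 'self'"
    print(exc)
```

The TypeError mirrors the second error in the original report; with langchain-core 0.3 the same missing self is caught one step earlier, as a pydantic ValidationError.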

@jason571
Author

class SQLToolManager:
    def __init__(self):
        self.interface = Interface()
        self.db_handler = SQLiteHandler()
        self.llm = self.interface.get_current_model()
        self.toolkit = SQLDatabaseToolkit(db=self.db_handler.get_sql_database(), llm=self.llm)
        self.tools = self.toolkit.get_tools()
        self.list_tables_tool = next(tool for tool in self.tools if tool.name == "sql_db_list_tables")
        self.get_schema_tool = next(tool for tool in self.tools if tool.name == "sql_db_schema")

    def create_tool_node_with_fallback(self, tools: list) -> RunnableWithFallbacks[Any, dict]:
        """
        Create a ToolNode with a fallback to handle errors and surface them to the agent.
        """
        return ToolNode(tools).with_fallbacks(
            [RunnableLambda(self.handle_tool_error)], exception_key="error"
        )

    def handle_tool_error(self, state) -> dict:
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {
            "messages": [
                ToolMessage(
                    content=f"Error: {repr(error)}\n please fix your mistakes.",
                    tool_call_id=tc["id"],
                )
                for tc in tool_calls
            ]
        }

    @tool
    def db_query_tool(self, query: str) -> str:
        """
        Execute a SQL query against the database and get back the result.
        If the query is not correct, an error message will be returned.
        """
        mylogging.info(f"Executing query: {query}")
        result = self.db_handler.db.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result

    def create_query_check(self):
        query_check_system = """You are a SQL expert with a strong attention to detail.
        Double check the SQLite query for common mistakes, including:
        - Using NOT IN with NULL values
        - Using UNION when UNION ALL should have been used
        - Using BETWEEN for exclusive ranges
        - Data type mismatch in predicates
        - Properly quoting identifiers
        - Using the correct number of arguments for functions
        - Casting to the correct data type
        - Using the proper columns for joins

        If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
        You will call the appropriate tool to execute the query after running this check."""

        query_check_prompt = ChatPromptTemplate.from_messages(
            [("system", query_check_system), ("placeholder", "{messages}")]
        )
        return query_check_prompt | self.llm.bind_tools([self.db_query_tool], tool_choice="required")

    def list_tables(self):
        return self.list_tables_tool.invoke("")

    def get_schema(self, table_name: str):
        return self.get_schema_tool.invoke(table_name)

    def run_query(self, query: str):
        return self.db_query_tool.invoke({"query": query})

    def check_and_run_query(self, query: str):
        return self.query_check.invoke({"messages": [("user", query)]})

Test code:

if __name__ == "__main__":
    manager = SQLToolManager()
    print(manager.list_tables())

    print(manager.get_schema_tool.invoke("Artist"))

    # db_query_tool
    input_query = "SELECT * FROM Artist LIMIT 10;"
    result = manager.db_query_tool.invoke({"query": input_query})
    print(result)

    result = manager.run_query(input_query)
    print(result)

    # Check and run the query
    result = manager.check_and_run_query(input_query)
    print(result)

  File "/mnt/c/workspace/pr_train/LLMs/src/sqlAgent/sqlTools.py", line 145, in <module>
    result = manager.db_query_tool.invoke({"query": input_query})
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 489, in invoke
    return self.run(tool_input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 692, in run
    raise error_to_raise
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 655, in run
    tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 578, in _to_args_and_kwargs
    tool_input = self._parse_input(tool_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 520, in _parse_input
    result = input_args.model_validate(tool_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
    return cls.__pydantic_validator__.validate_python(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
self
  Field required [type=missing, input_value={'query': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.9/v/missing
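The ValidationError says self is a required field of the db_query_tool input schema. LangChain builds that schema (a pydantic model) from the function signature; the hypothetical helpers below (required_fields, validate_input) only approximate the required-field check with inspect.signature, but they show the difference between the function as a class-level decorator sees it and a bound method:

```python
import inspect
from typing import Any, Callable


def required_fields(func: Callable[..., Any]) -> list[str]:
    """Hypothetical helper: names of parameters that have no default value."""
    return [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.default is inspect.Parameter.empty
    ]


def validate_input(func: Callable[..., Any], tool_input: dict) -> list[str]:
    """Hypothetical helper: required fields missing from tool_input (empty means valid)."""
    return [name for name in required_fields(func) if name not in tool_input]


class SQLToolManager:
    def db_query_tool(self, query: str) -> str:
        return query


payload = {"query": "SELECT * FROM Artist LIMIT 10;"}

# Accessed through the class (as a decorator sees it), self is an ordinary parameter.
unbound = SQLToolManager.db_query_tool
print(validate_input(unbound, payload))  # ['self']

# A bound method already carries self, so only query is required.
bound = SQLToolManager().db_query_tool
print(validate_input(bound, payload))  # []
```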

@jason571
Author

System Information

OS: Linux
OS Version: #3672-Microsoft Fri Jan 01 08:00:00 PST 2016
Python Version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Package Information

langchain_core: 0.3.0
langchain: 0.3.0
langchain_community: 0.3.0
langsmith: 0.1.128
langchain_cli: 0.0.31
langchain_cohere: 0.3.0
langchain_experimental: 0.3.0
langchain_google_community: 2.0.0
langchain_huggingface: 0.1.0
langchain_milvus: 0.1.5
langchain_openai: 0.2.0
langchain_text_splitters: 0.3.0
langgraph: 0.2.22
langserve: 0.2.2

@jason571
Author

input_query = "SELECT * FROM Artist LIMIT 10;"
result = manager.db_query_tool({"query": input_query})

  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
    return cls.__pydantic_validator__.validate_python(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
self
  Field required [type=missing, input_value={'query': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]

input_query = "SELECT * FROM Artist LIMIT 10;"
result = manager.db_query_tool(input_query)
#result = manager.db_query_tool.invoke({"query": input_query})
print(result)

  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 513, in _parse_input
    input_args.model_validate({key_: tool_input})

  File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
    return cls.__pydantic_validator__.validate_python(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
query
  Field required [type=missing, input_value={'self': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.9/v/missing
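The second failure in this comment comes from the {key_: tool_input} step visible in the traceback: when a tool is invoked with a bare string, the string is assigned to the first field of its schema. A rough approximation, assuming the first field is simply the first parameter of the signature (parse_string_input is a hypothetical helper, not the LangChain function), shows why the SQL ends up under self and query is reported missing:

```python
import inspect
from typing import Any, Callable


def parse_string_input(func: Callable[..., Any], tool_input: str) -> dict:
    """Hypothetical approximation of the {key_: tool_input} step:
    a bare string is assigned to the first parameter of the signature."""
    first_param = next(iter(inspect.signature(func).parameters))
    return {first_param: tool_input}


class SQLToolManager:
    def db_query_tool(self, query: str) -> str:
        return query


def standalone_query_tool(query: str) -> str:
    """Same tool written as a plain function, with no self parameter."""
    return query


query = "SELECT * FROM Artist LIMIT 10;"

# Seen through the class, the first parameter is self, so the SQL lands there
# and the required query field is left empty.
print(parse_string_input(SQLToolManager.db_query_tool, query))
# {'self': 'SELECT * FROM Artist LIMIT 10;'}

print(parse_string_input(standalone_query_tool, query))
# {'query': 'SELECT * FROM Artist LIMIT 10;'}
```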

@hinthornw
Contributor

hinthornw commented Oct 2, 2024

The @tool decorator isn't supported directly on methods right now (the parent instance isn't bound at the time of decoration).

I believe something like this works:

@property
def db_query_tool(self):
    @tool
    def query_db(query: str):
        """
        Execute a SQL query against the database and get back the result.
        If the query is not correct, an error message will be returned.
        """
        mylogging.info(f"Executing query: {query}")
        result = self.db_handler.db.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result
    return query_db
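The workaround above works because the nested function is decorated only when the property is accessed, after the instance exists, and it reaches self through a closure rather than through a parameter. A library-free sketch of the same pattern, using a hypothetical FakeTool stand-in (not the LangChain class) and a prefix attribute in place of db_handler:

```python
import inspect
from typing import Any, Callable


class FakeTool:
    """Hypothetical stand-in for a function-based tool (not the LangChain class)."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        self.args = list(inspect.signature(func).parameters)

    def invoke(self, tool_input: dict) -> Any:
        return self.func(**tool_input)


def fake_tool(func: Callable[..., Any]) -> FakeTool:
    """Hypothetical decorator standing in for @tool."""
    return FakeTool(func)


class SQLToolManager:
    def __init__(self, prefix: str) -> None:
        self.prefix = prefix  # stands in for self.db_handler in the issue

    @property
    def db_query_tool(self) -> FakeTool:
        # Decorate a nested function after the instance exists; it reaches the
        # instance through the closure, so self is not a tool argument.
        @fake_tool
        def query_db(query: str) -> str:
            return f"{self.prefix}: {query}"

        return query_db


manager = SQLToolManager("ran")
print(manager.db_query_tool.args)  # ['query']
print(manager.db_query_tool.invoke({"query": "SELECT 1"}))  # ran: SELECT 1
```

One trade-off of the property form is that every attribute access builds a new tool object. If that matters, functools.cached_property would build it once per instance.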

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants