Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQL agent and Spider environment #1218

Merged
merged 13 commits into from
Jan 27, 2024
316 changes: 316 additions & 0 deletions notebook/agentchat_sql_spider.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SQL Agent for Spider text-to-SQL benchmark"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environment\n",
"\n",
"For this demo, we use a SQLite database environment based on a standard text-to-sql benchmark called [Spider](https://yale-lily.github.io/spider). The environment provides a gym-like interface and can be used as follows."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading cached Spider dataset from /home/wangdazhang/.cache/spider\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/flight_4\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/small_bank_1\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/icfp_1\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/twitter_1\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/epinions_1\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/chinook_1\n",
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/company_1\n"
]
}
],
"source": [
"# %pip install spider-env\n",
"from spider_env import SpiderEnv\n",
"\n",
"from autogen import UserProxyAgent, ConversableAgent, config_list_from_json\n",
"from typing import Annotated, Dict\n",
"import json\n",
"import os\n",
"\n",
"gym = SpiderEnv()\n",
"\n",
"# Randomly select a question from Spider\n",
"observation, info = gym.reset()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Find the famous titles of artists that do not have any volume.\n"
]
}
],
"source": [
"# The natural language question\n",
"question = observation[\"instruction\"]\n",
"print(question)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CREATE TABLE \"artist\" (\n",
"\"Artist_ID\" int,\n",
"\"Artist\" text,\n",
"\"Age\" int,\n",
"\"Famous_Title\" text,\n",
"\"Famous_Release_date\" text,\n",
"PRIMARY KEY (\"Artist_ID\")\n",
");\n",
"CREATE TABLE \"volume\" (\n",
"\"Volume_ID\" int,\n",
"\"Volume_Issue\" text,\n",
"\"Issue_Date\" text,\n",
"\"Weeks_on_Top\" real,\n",
"\"Song\" text,\n",
"\"Artist_ID\" int,\n",
"PRIMARY KEY (\"Volume_ID\"),\n",
"FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n",
");\n",
"CREATE TABLE \"music_festival\" (\n",
"\"ID\" int,\n",
"\"Music_Festival\" text,\n",
"\"Date_of_ceremony\" text,\n",
"\"Category\" text,\n",
"\"Volume\" int,\n",
"\"Result\" text,\n",
"PRIMARY KEY (\"ID\"),\n",
"FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n",
");\n",
"\n"
]
}
],
"source": [
"# The schema of the corresponding database\n",
"schema = info[\"schema\"]\n",
"print(schema)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent Implementation\n",
"\n",
"Using AutoGen, a SQL agent can be implemented with a ConversableAgent. The gym environment executes the generated SQL query and the agent can take execution results as feedback to improve its generation in multiple rounds of conversations."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"AUTOGEN_USE_DOCKER\"] = \"False\"\n",
"config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
"\n",
"\n",
"def check_termination(msg: Dict):\n",
" if \"tool_responses\" not in msg:\n",
" return False\n",
" json_str = msg[\"tool_responses\"][0][\"content\"]\n",
" obj = json.loads(json_str)\n",
" return \"error\" not in obj or obj[\"error\"] is None and obj[\"reward\"] == 1\n",
"\n",
"\n",
"sql_writer = ConversableAgent(\n",
" \"sql_writer\",\n",
" llm_config={\"config_list\": config_list},\n",
" system_message=\"You are good at writing SQL queries. Always respond with a function call to execute_sql().\",\n",
" is_termination_msg=check_termination,\n",
")\n",
"user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=5)\n",
"\n",
"\n",
"@sql_writer.register_for_llm(description=\"Function for executing SQL query and returning a response\")\n",
"@user_proxy.register_for_execution()\n",
"def execute_sql(\n",
" reflection: Annotated[str, \"Think about what to do\"], sql: Annotated[str, \"SQL query\"]\n",
") -> Annotated[Dict[str, str], \"Dictionary with keys 'result' and 'error'\"]:\n",
" observation, reward, _, _, info = gym.step(sql)\n",
" error = observation[\"feedback\"][\"error\"]\n",
" if not error and reward == 0:\n",
" error = \"The SQL query returned an incorrect result\"\n",
" if error:\n",
" return {\n",
" \"error\": error,\n",
" \"wrong_result\": observation[\"feedback\"][\"result\"],\n",
" \"correct_result\": info[\"gold_result\"],\n",
" }\n",
" else:\n",
" return {\n",
" \"result\": observation[\"feedback\"][\"result\"],\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agent can then take as input the schema and the text question, and generate the SQL query."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
"\n",
"Below is the schema for a SQL database:\n",
"CREATE TABLE \"artist\" (\n",
"\"Artist_ID\" int,\n",
"\"Artist\" text,\n",
"\"Age\" int,\n",
"\"Famous_Title\" text,\n",
"\"Famous_Release_date\" text,\n",
"PRIMARY KEY (\"Artist_ID\")\n",
");\n",
"CREATE TABLE \"volume\" (\n",
"\"Volume_ID\" int,\n",
"\"Volume_Issue\" text,\n",
"\"Issue_Date\" text,\n",
"\"Weeks_on_Top\" real,\n",
"\"Song\" text,\n",
"\"Artist_ID\" int,\n",
"PRIMARY KEY (\"Volume_ID\"),\n",
"FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n",
");\n",
"CREATE TABLE \"music_festival\" (\n",
"\"ID\" int,\n",
"\"Music_Festival\" text,\n",
"\"Date_of_ceremony\" text,\n",
"\"Category\" text,\n",
"\"Volume\" int,\n",
"\"Result\" text,\n",
"PRIMARY KEY (\"ID\"),\n",
"FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n",
");\n",
"\n",
"Generate a SQL query to answer the following question:\n",
"Find the famous titles of artists that do not have any volume.\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33msql_writer\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_eAu0OEzS8l3QvN3jQSn4w0hJ): execute_sql *****\u001b[0m\n",
"Arguments: \n",
"{\"reflection\":\"Generating SQL to find famous titles of artists without any volume\",\"sql\":\"SELECT a.Artist, a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n",
"\u001b[32m****************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_eAu0OEzS8l3QvN3jQSn4w0hJ\" *****\u001b[0m\n",
"{\"error\": \"The SQL query returned an incorrect result\", \"wrong_result\": [[\"Ophiolatry\", \"Antievangelistical Process (re-release)\"], [\"Triumfall\", \"Antithesis of All Flesh\"]], \"correct_result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33msql_writer\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_5LXoKqdZ17kPCOHJbbpSz2yk): execute_sql *****\u001b[0m\n",
"Arguments: \n",
"{\"reflection\":\"Adjusting SQL to only select famous titles and exclude artist names for artists without any volume.\",\"sql\":\"SELECT a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n",
"\u001b[32m****************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_5LXoKqdZ17kPCOHJbbpSz2yk\" *****\u001b[0m\n",
"{\"result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
]
}
],
"source": [
"message = f\"\"\"Below is the schema for a SQL database:\n",
"{schema}\n",
"Generate a SQL query to answer the following question:\n",
"{question}\n",
"\"\"\"\n",
"\n",
"user_proxy.initiate_chat(sql_writer, message=message)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions website/docs/Examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Links to notebook examples:
- Automated Task Solving with Code Generation, Execution & Debugging - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)
- Automated Code Generation and Question Answering with Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb)
- Automated Code Generation and Question Answering with [Qdrant](https://qdrant.tech/) based Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb)
- Natural Language Text to SQL Query using the [Spider](https://yale-lily.github.io/spider) Text-to-SQL Benchmark - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_sql_spider.ipynb)
sonichi marked this conversation as resolved.
Show resolved Hide resolved

1. **Multi-Agent Collaboration (>3 Agents)**

Expand Down
Loading