-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make a notebook. Clean up environment.
- Loading branch information
Showing
3 changed files
with
317 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,317 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# SQL Agent" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Environment\n", | ||
"\n", | ||
"For this demo, we use a SQLite database environment based on a standard text-to-sql benchmark called [Spider](https://yale-lily.github.io/spider). The environment provides a gym-like interface and can be used as follows." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading cached Spider dataset from /home/wangdazhang/.cache/spider\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/flight_4\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/small_bank_1\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/icfp_1\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/twitter_1\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/epinions_1\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/chinook_1\n", | ||
"Schema file not found for /home/wangdazhang/.cache/spider/spider/database/company_1\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# %pip install spider-env\n", | ||
"\n", | ||
"from spider_env import SpiderEnv\n", | ||
"\n", | ||
"gym = SpiderEnv()\n", | ||
"\n", | ||
"# Randomly select a question from Spider\n", | ||
"observation, info = gym.reset()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Find the famous titles of artists that do not have any volume.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# The natural language question\n", | ||
"question = observation[\"instruction\"]\n", | ||
"print(question)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"CREATE TABLE \"artist\" (\n", | ||
"\"Artist_ID\" int,\n", | ||
"\"Artist\" text,\n", | ||
"\"Age\" int,\n", | ||
"\"Famous_Title\" text,\n", | ||
"\"Famous_Release_date\" text,\n", | ||
"PRIMARY KEY (\"Artist_ID\")\n", | ||
");\n", | ||
"CREATE TABLE \"volume\" (\n", | ||
"\"Volume_ID\" int,\n", | ||
"\"Volume_Issue\" text,\n", | ||
"\"Issue_Date\" text,\n", | ||
"\"Weeks_on_Top\" real,\n", | ||
"\"Song\" text,\n", | ||
"\"Artist_ID\" int,\n", | ||
"PRIMARY KEY (\"Volume_ID\"),\n", | ||
"FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n", | ||
");\n", | ||
"CREATE TABLE \"music_festival\" (\n", | ||
"\"ID\" int,\n", | ||
"\"Music_Festival\" text,\n", | ||
"\"Date_of_ceremony\" text,\n", | ||
"\"Category\" text,\n", | ||
"\"Volume\" int,\n", | ||
"\"Result\" text,\n", | ||
"PRIMARY KEY (\"ID\"),\n", | ||
"FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n", | ||
");\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# The schema of the corresponding database\n", | ||
"schema = info[\"schema\"]\n", | ||
"print(schema)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Agent Implementation\n", | ||
"\n", | ||
"Using AutoGen, a SQL agent can be implemented with a ConversableAgent. The gym environment executes the generated SQL query and the agent can take execution results as feedback to improve its generation in multiple rounds of conversations." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from autogen import UserProxyAgent, ConversableAgent, config_list_from_json\n", | ||
"from typing import Annotated, Dict\n", | ||
"import json\n", | ||
"import os\n", | ||
"\n", | ||
"os.environ[\"AUTOGEN_USE_DOCKER\"] = \"False\"\n", | ||
"config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n", | ||
"\n", | ||
"\n", | ||
"def check_termination(msg: Dict):\n", | ||
" if \"tool_responses\" not in msg:\n", | ||
" return False\n", | ||
" json_str = msg[\"tool_responses\"][0][\"content\"]\n", | ||
" obj = json.loads(json_str)\n", | ||
" return \"error\" not in obj or obj[\"error\"] is None and obj[\"reward\"] == 1\n", | ||
"\n", | ||
"\n", | ||
"sql_writer = ConversableAgent(\n", | ||
" \"sql_writer\",\n", | ||
" llm_config={\"config_list\": config_list},\n", | ||
" system_message=\"You are good at writing SQL queries. Always respond with a function call to execute_sql().\",\n", | ||
" is_termination_msg=check_termination,\n", | ||
")\n", | ||
"user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=5)\n", | ||
"\n", | ||
"\n", | ||
"@sql_writer.register_for_llm(description=\"Function for executing SQL query and returning a response\")\n", | ||
"@user_proxy.register_for_execution()\n", | ||
"def execute_sql(\n", | ||
" reflection: Annotated[str, \"Think about what to do\"], sql: Annotated[str, \"SQL query\"]\n", | ||
") -> Annotated[Dict[str, str], \"Dictionary with keys 'result' and 'error'\"]:\n", | ||
" observation, reward, _, _, info = gym.step(sql)\n", | ||
" error = observation[\"feedback\"][\"error\"]\n", | ||
" if not error and reward == 0:\n", | ||
" error = \"The SQL query returned an incorrect result\"\n", | ||
" if error:\n", | ||
" return {\n", | ||
" \"error\": error,\n", | ||
" \"wrong_result\": observation[\"feedback\"][\"result\"],\n", | ||
" \"correct_result\": info[\"gold_result\"],\n", | ||
" }\n", | ||
" else:\n", | ||
" return {\n", | ||
" \"result\": observation[\"feedback\"][\"result\"],\n", | ||
" }" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The agent can then take as input the schema and the text question, and generate the SQL query." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", | ||
"\n", | ||
"Below is the schema for a SQL database:\n", | ||
"CREATE TABLE \"artist\" (\n", | ||
"\"Artist_ID\" int,\n", | ||
"\"Artist\" text,\n", | ||
"\"Age\" int,\n", | ||
"\"Famous_Title\" text,\n", | ||
"\"Famous_Release_date\" text,\n", | ||
"PRIMARY KEY (\"Artist_ID\")\n", | ||
");\n", | ||
"CREATE TABLE \"volume\" (\n", | ||
"\"Volume_ID\" int,\n", | ||
"\"Volume_Issue\" text,\n", | ||
"\"Issue_Date\" text,\n", | ||
"\"Weeks_on_Top\" real,\n", | ||
"\"Song\" text,\n", | ||
"\"Artist_ID\" int,\n", | ||
"PRIMARY KEY (\"Volume_ID\"),\n", | ||
"FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n", | ||
");\n", | ||
"CREATE TABLE \"music_festival\" (\n", | ||
"\"ID\" int,\n", | ||
"\"Music_Festival\" text,\n", | ||
"\"Date_of_ceremony\" text,\n", | ||
"\"Category\" text,\n", | ||
"\"Volume\" int,\n", | ||
"\"Result\" text,\n", | ||
"PRIMARY KEY (\"ID\"),\n", | ||
"FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n", | ||
");\n", | ||
"\n", | ||
"Generate a SQL query to answer the following question:\n", | ||
"Find the famous titles of artists that do not have any volume.\n", | ||
"\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[31m\n", | ||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n", | ||
"\u001b[33msql_writer\u001b[0m (to user_proxy):\n", | ||
"\n", | ||
"\u001b[32m***** Suggested tool Call (call_eAu0OEzS8l3QvN3jQSn4w0hJ): execute_sql *****\u001b[0m\n", | ||
"Arguments: \n", | ||
"{\"reflection\":\"Generating SQL to find famous titles of artists without any volume\",\"sql\":\"SELECT a.Artist, a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n", | ||
"\u001b[32m****************************************************************************\u001b[0m\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[35m\n", | ||
">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n", | ||
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", | ||
"\n", | ||
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", | ||
"\n", | ||
"\u001b[32m***** Response from calling tool \"call_eAu0OEzS8l3QvN3jQSn4w0hJ\" *****\u001b[0m\n", | ||
"{\"error\": \"The SQL query returned an incorrect result\", \"wrong_result\": [[\"Ophiolatry\", \"Antievangelistical Process (re-release)\"], [\"Triumfall\", \"Antithesis of All Flesh\"]], \"correct_result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n", | ||
"\u001b[32m**********************************************************************\u001b[0m\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[31m\n", | ||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n", | ||
"\u001b[33msql_writer\u001b[0m (to user_proxy):\n", | ||
"\n", | ||
"\u001b[32m***** Suggested tool Call (call_5LXoKqdZ17kPCOHJbbpSz2yk): execute_sql *****\u001b[0m\n", | ||
"Arguments: \n", | ||
"{\"reflection\":\"Adjusting SQL to only select famous titles and exclude artist names for artists without any volume.\",\"sql\":\"SELECT a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n", | ||
"\u001b[32m****************************************************************************\u001b[0m\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[35m\n", | ||
">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n", | ||
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", | ||
"\n", | ||
"\u001b[33muser_proxy\u001b[0m (to sql_writer):\n", | ||
"\n", | ||
"\u001b[32m***** Response from calling tool \"call_5LXoKqdZ17kPCOHJbbpSz2yk\" *****\u001b[0m\n", | ||
"{\"result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n", | ||
"\u001b[32m**********************************************************************\u001b[0m\n", | ||
"\n", | ||
"--------------------------------------------------------------------------------\n", | ||
"\u001b[31m\n", | ||
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"message = f\"\"\"Below is the schema for a SQL database:\n", | ||
"{schema}\n", | ||
"Generate a SQL query to answer the following question:\n", | ||
"{question}\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"user_proxy.initiate_chat(sql_writer, message=message)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.