Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harrison/sql rows #915

Merged
merged 2 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 60 additions & 30 deletions docs/modules/chains/examples/sqlite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3d1e692e",
"metadata": {},
Expand Down Expand Up @@ -94,15 +93,15 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' There are 9 employees.'"
"' There are 8 employees.'"
]
},
"execution_count": 4,
Expand Down Expand Up @@ -177,15 +176,15 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' There are 9 employees in the foobar table.'"
"' There are 8 employees in the foobar table.'"
]
},
"execution_count": 7,
Expand Down Expand Up @@ -219,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "78b6af4d",
"metadata": {},
"outputs": [
Expand All @@ -232,18 +231,18 @@
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 8 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"[' SELECT COUNT(*) FROM Employee;', '[(9,)]']"
"[' SELECT COUNT(*) FROM Employee;', '[(8,)]']"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -264,7 +263,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"id": "6adaa799",
"metadata": {},
"outputs": [],
Expand All @@ -274,7 +273,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"id": "edfc8a8e",
"metadata": {},
"outputs": [
Expand All @@ -286,8 +285,8 @@
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What are some example tracks by composer Johann Sebastian Bach? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',)]\u001b[0m\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by Johann Sebastian Bach include 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
Expand All @@ -298,7 +297,7 @@
"' Examples of tracks by Johann Sebastian Bach include \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'"
]
},
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -312,26 +311,54 @@
"id": "bcc5e936",
"metadata": {},
"source": [
"## Adding first row of each table\n",
"Sometimes, the format of the data is not obvious and it is optimal to include the first row of the table in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names."
"## Adding example rows from each table\n",
"Sometimes, the format of the data is not obvious and it is optimal to include a sample of rows from the tables in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names by providing two rows from the `Track` table."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "9a22ee47",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\n",
" \"sqlite:///../../../../notebooks/Chinook.db\", \n",
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
" sample_row_in_table_info=True)"
" sample_rows_in_table_info=2)"
]
},
{
"cell_type": "markdown",
"id": "952c0b4d",
"metadata": {},
"source": [
"The sample rows are added to the prompt after each corresponding table's column information:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "9de86267",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example of 2 rows from this table (long strings are truncated):\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
]
}
],
"source": [
"print(db.table_info)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bcb7a489",
"metadata": {},
"outputs": [],
Expand All @@ -341,7 +368,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "81e05d82",
"metadata": {},
"outputs": [
Expand All @@ -353,20 +380,19 @@
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What are some example tracks by Bach? \n",
"SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n",
"\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'"
"' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'"
]
},
"execution_count": 13,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -455,6 +481,10 @@
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
Expand All @@ -470,7 +500,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.2"
}
},
"nbformat": 4,
Expand Down
43 changes: 34 additions & 9 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@ def __init__(
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
# TODO: deprecate.
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
if sample_row_in_table_info and sample_rows_in_table_info > 0:
raise ValueError(
"Only one of `sample_row_in_table_info` "
"and `sample_rows_in_table_info` should be set"
)
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
Expand All @@ -40,7 +47,10 @@ def __init__(
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_row_in_table_info = sample_row_in_table_info
self._sample_rows_in_table_info = sample_rows_in_table_info
# TODO: deprecate
if sample_row_in_table_info:
self._sample_rows_in_table_info = 1

@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
Expand All @@ -64,7 +74,12 @@ def table_info(self) -> str:
return self.get_table_info()

def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
"""Get information about specified tables.

If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498).
"""
all_table_names = self.get_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
Expand All @@ -83,15 +98,25 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
column_str = ", ".join(columns)
table_str = template.format(table_name=table_name, columns=column_str)

if self._sample_row_in_table_info:
if self._sample_rows_in_table_info:
row_template = (
" Here is an example row for this table"
" (long strings are truncated): {sample_row}."
" Here is an example of {n_rows} rows from this table "
"(long strings are truncated):\n"
"{sample_rows}"
)
sample_rows = self.run(
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1")
if len(eval(sample_row)) > 0:
sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]])
table_str += row_template.format(sample_row=sample_row)
sample_rows = eval(sample_rows)
if len(sample_rows) > 0:
n_rows = len(sample_rows)
sample_rows = "\n".join(
[" ".join([str(i)[:100] for i in row]) for row in sample_rows]
)
table_str += row_template.format(
n_rows=n_rows, sample_rows=sample_rows
)

tables.append(table_str)
return "\n".join(tables)
Expand Down
14 changes: 9 additions & 5 deletions tests/unit_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,27 @@ def test_table_info() -> None:
assert sorted(output.split("\n")) == sorted(expected_output)


def test_table_info_w_sample_row() -> None:
def test_table_info_w_sample_rows() -> None:
"""Test that table info is constructed properly."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison")
values = [
{"user_id": 13, "user_name": "Harrison"},
{"user_id": 14, "user_name": "Chase"},
]
stmt = insert(user).values(values)
with engine.begin() as conn:
conn.execute(stmt)

db = SQLDatabase(engine, sample_row_in_table_info=True)
db = SQLDatabase(engine, sample_rows_in_table_info=2)

output = db.table_info
expected_output = (
"Table 'company' has columns: company_id (INTEGER), "
"company_location (VARCHAR).\n"
"Table 'user' has columns: user_id (INTEGER), "
"user_name (VARCHAR(16)). Here is an example row "
"for this table (long strings are truncated): 13 Harrison."
"user_name (VARCHAR(16)). Here is an example of 2 rows "
"from this table (long strings are truncated):\n13 Harrison\n14 Chase"
)
assert sorted(output.split("\n")) == sorted(expected_output.split("\n"))

Expand Down