Skip to content

Commit

Permalink
refactor joins (#932)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Jan 10, 2025
1 parent 6b48e35 commit a38478a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
26 changes: 13 additions & 13 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,9 @@ class SQLAgent(LumenBaseAgent):

_output_type = SQLOutput

async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]:
async def _select_relevant_table(self, messages: list[Message], sources: list, tables_to_source: dict, tables_schema_str: str) -> tuple[str, BaseSQLSource, bool]:
"""Select the most relevant table based on the user query."""
join_required = None
sources = self._memory["sources"]
tables_to_source, tables_schema_str = await gather_table_sources(sources)
tables = tuple(tables_to_source)

user_message = ""
Expand Down Expand Up @@ -605,16 +603,18 @@ async def _create_valid_sql(
async def _check_requires_joins(
self,
messages: list[Message],
schema,
table: str
table: str,
schema: str,
tables_schema_str: str,
):
requires_joins = None
with self.interface.add_step(title="Checking if join is required", steps_layout=self._steps_layout) as step:
join_prompt = await self._render_prompt(
"require_joins",
messages,
schema=yaml.dump(schema),
table=table
table=table,
schema=schema, # this contains current table schema
tables_schema_str=tables_schema_str # this may not be populated with any schemas
)
model_spec = self.prompts["require_joins"].get("llm_spec", "default")
response = self.llm.stream(
Expand Down Expand Up @@ -643,7 +643,7 @@ async def find_join_tables(self, messages: list[Message]):
else:
tables = self._memory['source'].get_tables()

find_joins_prompt = await self._render_prompt("find_joins", messages, tables=tables)
find_joins_prompt = await self._render_prompt("find_joins", messages, tables=tables, separator=SOURCE_TABLE_SEPARATOR)
model_spec = self.prompts["find_joins"].get("llm_spec", "default")
with self.interface.add_step(title="Determining tables required for join", steps_layout=self._steps_layout) as step:
output = await self.llm.invoke(
Expand All @@ -662,7 +662,7 @@ async def find_join_tables(self, messages: list[Message]):
tables_to_source = {}
for source_table in tables_to_join:
sources = self._memory["sources"]
if multi_source:
if multi_source and SOURCE_TABLE_SEPARATOR in source_table:
try:
_, a_source_name, a_table = source_table.split(SOURCE_TABLE_SEPARATOR, maxsplit=2)
except ValueError:
Expand Down Expand Up @@ -701,16 +701,16 @@ async def respond(
8. If a join is required, remove source/table prefixes from the last message.
9. Construct the SQL query with `_create_valid_sql`.
"""
table, source, join_required = await self._select_relevant_table(messages)
sources = self._memory["sources"]
tables_to_source, tables_schema_str = await gather_table_sources(sources)
table, source, join_required = await self._select_relevant_table(messages, sources, tables_to_source, tables_schema_str)
if not hasattr(source, "get_sql_expr"):
return None

# include min max for more context for data cleaning
schema = await get_schema(source, table, include_min_max=True)

tables_to_source = {table: source}
if join_required is None:
join_required = await self._check_requires_joins(messages, schema, table)
join_required = await self._check_requires_joins(messages, table, schema, tables_schema_str)
if join_required is None:
# Bail if query was cancelled or errored out
return None
Expand Down
6 changes: 4 additions & 2 deletions lumen/ai/prompts/SQLAgent/find_joins.jinja2
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
{% extends 'Actor/main.jinja2' %}

{% block instructions %}
Correctly assess and list the tables that need to be joined; be sure to include both `//`.
Correctly assess and list the tables that need to be joined.

Use table names verbatim:

- if the table is read_csv('table.csv'), then use read_csv('table.csv') and not 'table' or 'table.csv'
- if the table is table.csv, then use table.csv and not read_csv('table.csv')

If there are delimiters, '{{ separator }}', be sure to include them:
'{{ separator }}source{{ separator }}table{{ separator }}''
{% endblock %}

{% block context %}
Expand Down
8 changes: 6 additions & 2 deletions lumen/ai/prompts/SQLAgent/require_joins.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@ Determine whether a table join is required to answer the user's query.

Notes:
- Carefully consider the columns described in the schema
- If the schema references all the columns needed, no join is required
- Consider if missing data can be calculated without a join

A join is unnecessary if the schema already includes all the required columns, unless the schema values are untruncated and obviously missing necessary values in the min/max/enum.
{% endblock %}

{% block context %}
The current table '{{ table }}' follows this YAML schema:

The current table '{{ table }}' follows this YAML schema:
```yaml
{{ schema }}
```

Here are all the tables (and schemas if available):
{{ tables_schema_str }}
{%- endblock -%}

0 comments on commit a38478a

Please sign in to comment.