diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index 68a54f17..23f54502 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -459,11 +459,9 @@ class SQLAgent(LumenBaseAgent): _output_type = SQLOutput - async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]: + async def _select_relevant_table(self, messages: list[Message], sources: list, tables_to_source: dict, tables_schema_str: str) -> tuple[str, BaseSQLSource, bool]: """Select the most relevant table based on the user query.""" join_required = None - sources = self._memory["sources"] - tables_to_source, tables_schema_str = await gather_table_sources(sources) tables = tuple(tables_to_source) user_message = "" @@ -605,16 +603,18 @@ async def _create_valid_sql( async def _check_requires_joins( self, messages: list[Message], - schema, - table: str + table: str, + schema: str, + tables_schema_str: str, ): requires_joins = None with self.interface.add_step(title="Checking if join is required", steps_layout=self._steps_layout) as step: join_prompt = await self._render_prompt( "require_joins", messages, - schema=yaml.dump(schema), - table=table + table=table, + schema=schema, # this contains current table schema + tables_schema_str=tables_schema_str # this may not be populated with any schemas ) model_spec = self.prompts["require_joins"].get("llm_spec", "default") response = self.llm.stream( @@ -643,7 +643,7 @@ async def find_join_tables(self, messages: list[Message]): else: tables = self._memory['source'].get_tables() - find_joins_prompt = await self._render_prompt("find_joins", messages, tables=tables) + find_joins_prompt = await self._render_prompt("find_joins", messages, tables=tables, separator=SOURCE_TABLE_SEPARATOR) model_spec = self.prompts["find_joins"].get("llm_spec", "default") with self.interface.add_step(title="Determining tables required for join", steps_layout=self._steps_layout) as step: output = await self.llm.invoke( @@ -662,7 +662,7 @@ async def 
find_join_tables(self, messages: list[Message]): tables_to_source = {} for source_table in tables_to_join: sources = self._memory["sources"] - if multi_source: + if multi_source and SOURCE_TABLE_SEPARATOR in source_table: try: _, a_source_name, a_table = source_table.split(SOURCE_TABLE_SEPARATOR, maxsplit=2) except ValueError: @@ -701,16 +701,16 @@ async def respond( 8. If a join is required, remove source/table prefixes from the last message. 9. Construct the SQL query with `_create_valid_sql`. """ - table, source, join_required = await self._select_relevant_table(messages) + sources = self._memory["sources"] + tables_to_source, tables_schema_str = await gather_table_sources(sources) + table, source, join_required = await self._select_relevant_table(messages, sources, tables_to_source, tables_schema_str) if not hasattr(source, "get_sql_expr"): return None - # include min max for more context for data cleaning schema = await get_schema(source, table, include_min_max=True) - tables_to_source = {table: source} if join_required is None: - join_required = await self._check_requires_joins(messages, schema, table) + join_required = await self._check_requires_joins(messages, table, schema, tables_schema_str) if join_required is None: # Bail if query was cancelled or errored out return None diff --git a/lumen/ai/prompts/SQLAgent/find_joins.jinja2 b/lumen/ai/prompts/SQLAgent/find_joins.jinja2 index 1100aeb6..8ab7fc2d 100644 --- a/lumen/ai/prompts/SQLAgent/find_joins.jinja2 +++ b/lumen/ai/prompts/SQLAgent/find_joins.jinja2 @@ -1,12 +1,14 @@ {% extends 'Actor/main.jinja2' %} {% block instructions %} -Correctly assess and list the tables that need to be joined; be sure to include both `//`. +Correctly assess and list the tables that need to be joined. 
Use table names verbatim: - - if the table is read_csv('table.csv'), then use read_csv('table.csv') and not 'table' or 'table.csv' - if the table is table.csv, then use table.csv and not read_csv('table.csv') + +If there are delimiters, '{{ separator }}', be sure to include them: +'{{ separator }}source{{ separator }}table{{ separator }}' {% endblock %} {% block context %} diff --git a/lumen/ai/prompts/SQLAgent/require_joins.jinja2 b/lumen/ai/prompts/SQLAgent/require_joins.jinja2 index 914c3a32..83bc73b9 100644 --- a/lumen/ai/prompts/SQLAgent/require_joins.jinja2 +++ b/lumen/ai/prompts/SQLAgent/require_joins.jinja2 @@ -5,14 +5,18 @@ Determine whether a table join is required to answer the user's query. Notes: - Carefully consider the columns described in the schema -- If the schema references all the columns needed, no join is required - Consider if missing data can be calculated without a join + +A join is unnecessary if the schema already includes all the required columns, unless the schema values are untruncated and obviously missing necessary values in the min/max/enum. {% endblock %} {% block context %} -The current table '{{ table }}' follows this YAML schema: +The current table '{{ table }}' follows this YAML schema: ```yaml {{ schema }} ``` + +Here are all the tables (and schemas if available): +{{ tables_schema_str }} {%- endblock -%}