diff --git a/src/sql_mock/helpers.py b/src/sql_mock/helpers.py index 9208ba5..39fbac2 100644 --- a/src/sql_mock/helpers.py +++ b/src/sql_mock/helpers.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List import sqlglot -from sqlglot import exp +from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.expressions import replace_tables from sqlglot.optimizer.scope import build_scope @@ -16,20 +16,30 @@ def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]: return set(key for dictionary in data for key in dictionary.keys()) +def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression: + """ + Remove a CTE from a query + + Args: + query_ast (sqlglot.Expression): The AST of the query + cte_name (str): The name of the CTE to remove + """ + for cte in query_ast.find_all(sqlglot.exp.CTE): + if cte.alias == cte_name: + cte.pop() + return query_ast -def replace_original_table_references(query: str, mock_tables: list["BaseTableMock"], dialect: str = None): +def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression: """ - Replace orignal table references to point them to the mocked data + Replace original table reference with sql mock cte name to point them to the mocked data Args: - query (str): Original SQL query - mock_tables (list[BaseTableMock]): List of BaseTableMock instances that are used as input + query_ast (str): Original SQL query - parsed by sqlglot + table_ref (str): Table ref to be replaced + sql_mock_cte_name (str): Name of the CTE that will contain the mocked data dialect (str): The SQL dialect to use for parsing the query """ - ast = sqlglot.parse_one(query, dialect=dialect) - mapping = {mock_table._sql_mock_meta.table_ref: mock_table._sql_mock_meta.cte_name for mock_table in mock_tables} - res = replace_tables(expression=ast, mapping=mapping, dialect=dialect).sql(pretty=True, dialect=dialect) - return res + return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect) def select_from_cte(query: str, cte_name: str, sql_dialect: str): @@ -66,7 +76,8 @@ def parse_table_refs(table_ref, dialect): return table_ref if not table_ref else str(sqlglot.parse_one(table_ref, dialect=dialect)) -def _strip_alias_transformer(node): +def _clean_table_ref_transformer(node): + node.comments = [] node.set("alias", None) return node @@ -81,7 +92,7 @@ def get_source_tables(query, dialect) -> List[str]: root = build_scope(ast) tables = { - str(source.transform(_strip_alias_transformer)) + str(source.transform(_clean_table_ref_transformer)) # Traverse the Scope tree, not the AST for scope in root.traverse() # `selected_sources` contains sources that have been selected in this scope, e.g. in a FROM or JOIN clause. @@ -92,7 +103,7 @@ def get_source_tables(query, dialect) -> List[str]: # if the selected source is a table, # then `source` will be a Table instance. for alias, (node, source) in scope.selected_sources.items() - if isinstance(source, exp.Table) + if isinstance(source, sqlglot.exp.Table) } return list(tables) @@ -125,7 +136,27 @@ def validate_input_mocks(input_mocks: List["BaseTableMock"]): def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["BaseTableMock"], dialect: str) -> None: - missing_source_table_mocks = get_source_tables(query=query, dialect=dialect) + """ + Validate that all input mocks are provided for a query. + Mocks can replace CTEs or tables in the query. If a CTE is replaced, upstream table references don't need to be provided anymore. + + Args: + query (str): The query to validate + input_mocks (List[BaseTableMock]): The input mocks that are provided + dialect (str): The SQL dialect to use for parsing the query + """ + provided_table_refs = [mock_table._sql_mock_meta.table_ref for mock_table in input_mocks if hasattr(mock_table._sql_mock_meta, "table_ref")] + ast = sqlglot.parse_one(query, dialect=dialect) + + # In case the table_ref is a CTE, we need to remove it from the query + for table_ref in provided_table_refs: + ast = remove_cte_from_query(query_ast=ast, cte_name=table_ref) + + # Now we might have some superfluous CTEs that are not referenced anymore + remaining_query = eliminate_ctes(ast).sql(dialect=dialect, pretty=True) + + # The remaining query should not contain raw table references anymore if everything is mocked correctly + missing_source_table_mocks = get_source_tables(query=remaining_query, dialect=dialect) for mock_table in input_mocks: table_ref = getattr(mock_table._sql_mock_meta, "table_ref", None) # If the table exists as mock, we can remove it from missing source tables diff --git a/src/sql_mock/table_mocks.py b/src/sql_mock/table_mocks.py index 3d5b99b..03d619b 100644 --- a/src/sql_mock/table_mocks.py +++ b/src/sql_mock/table_mocks.py @@ -2,6 +2,7 @@ from typing import List, Type import sqlglot +from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from jinja2 import Template from pydantic import BaseModel, ConfigDict, SkipValidation @@ -11,6 +12,7 @@ get_keys_from_list_of_dicts, parse_table_refs, replace_original_table_references, + remove_cte_from_query, select_from_cte, validate_all_input_mocks_for_query_provided, validate_input_mocks, @@ -72,7 +74,7 @@ class BaseTableMock: col1 = Int(default=1) Attributes: - _sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatcially created on instantiation. + _sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatically created on instantiation. _sql_dialect (str): The sql dialect that the mock model uses. It will be leveraged by sqlglot. """ @@ -184,9 +186,13 @@ def _generate_query( final_columns_to_select=final_columns_to_select, ) - query = replace_original_table_references( - query, mock_tables=self._sql_mock_data.input_data, dialect=self._sql_dialect - ) + query_ast = sqlglot.parse_one(query, dialect=self._sql_dialect) + for mock_table in self._sql_mock_data.input_data: + query_ast = mock_table.replace_original_references(query_ast=query_ast) + + # Remove superfluous CTEs + query_ast = eliminate_ctes(query_ast) + query = query_ast.sql(pretty=True, dialect=self._sql_dialect) # Store last query for debugging self._sql_mock_data.last_query = query return query @@ -235,6 +241,17 @@ def as_sql_input(self): snippet = indent(f"SELECT {snippet}", "\t") return f"{self._sql_mock_meta.cte_name} AS (\n{snippet}\n)" + def replace_original_references(self, query_ast: sqlglot.Expression) -> sqlglot.Expression: + # In case we mock a CTE, we need to drop the original CTE from the query + query_ast = remove_cte_from_query(query_ast=query_ast, cte_name=self._sql_mock_meta.table_ref) + + return replace_original_table_references( + query_ast=query_ast, + table_ref=self._sql_mock_meta.table_ref, + sql_mock_cte_name=self._sql_mock_meta.cte_name, + dialect=self._sql_dialect, + ) + def _assert_equal( self, data: [dict], @@ -343,4 +360,5 @@ class TableMockMeta(BaseModel): @property def cte_name(self): if getattr(self, "table_ref", None): - return self.table_ref.replace('"', "").replace(".", "__").replace("-", "_") + cleaned_ref = self.table_ref.replace('"', "").replace(".", "__").replace("-", "_") + return f"sql_mock__{cleaned_ref}" diff --git a/tests/sql_mock/test_helpers.py b/tests/sql_mock/test_helpers.py index c991a5e..e5b2ff5 100644 --- a/tests/sql_mock/test_helpers.py +++ b/tests/sql_mock/test_helpers.py @@ -48,18 +48,26 @@ class MockTestTable(BaseTableMock): class TestReplaceOriginalTableReference: def test_replace_original_table_references_when_reference_exists(self): """...then the original table reference should be replaced with the mocked table reference""" - query = f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}" - mock_tables = [MockTestTable()] + query_ast = sqlglot.parse_one(f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}") # Note that sqlglot will add a comment with the original table name at the end - expected = "SELECT\n *\nFROM data__mock_test_table /* data.mock_test_table */" - assert expected == replace_original_table_references(query, mock_tables) + expected = f"SELECT\n *\nFROM {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */" + assert expected == replace_original_table_references( + query_ast=query_ast, + table_ref=MockTestTable._sql_mock_meta.table_ref, + sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name, + dialect="bigquery", + ).sql(pretty=True) def test_replace_original_table_references_when_reference_does_not_exist(self): """...then the original reference should not be replaced""" - query = "SELECT * FROM some_table" - mock_tables = [MockTestTable()] + query_ast = sqlglot.parse_one("SELECT * FROM some_table") expected = "SELECT\n *\nFROM some_table" - assert expected == replace_original_table_references(query, mock_tables) + assert expected == replace_original_table_references( + query_ast=query_ast, + table_ref=MockTestTable._sql_mock_meta.table_ref, + sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name, + dialect="bigquery", + ).sql(pretty=True) class TestSelectFromCTE: @@ -190,6 +198,52 @@ def test_input_mocks_not_in_query(self): == f"Your input mock {TableRefMock.__class__.__name__} is not a table that is referenced in the query" ) + def test_input_mocks_missing_for_tables_within_mocked_cte(self): + """...then the validation should pass since the CTE would be mocked anyways""" + query = """ + WITH cte_1 AS ( + SELECT * FROM some_table + ), + + cte_2 AS ( + SELECT a, b + FROM cte_1 + WHERE a = 'foo' + ) + + SELECT a, b, * FROM cte_2 + """ + @table_meta(table_ref="cte_1") + class Cte1Mock(BaseTableMock): + pass + + validate_all_input_mocks_for_query_provided( + query=query, input_mocks=[Cte1Mock()], dialect="bigquery" + ) + + def test_cte_superfluous_after_mocking(self): + """...then the validation should pass since the CTE will be removed anyways and does not need to be mocked""" + query = """ + WITH cte_1 AS ( + SELECT * FROM some_table + ), + + cte_2 AS ( + SELECT a, b + FROM cte_1 + WHERE a = 'foo' + ) + + SELECT a, b, * FROM cte_2 + """ + @table_meta(table_ref="cte_2") # This will make cte_1 superfluous + class Cte1Mock(BaseTableMock): + pass + + validate_all_input_mocks_for_query_provided( + query=query, input_mocks=[Cte1Mock()], dialect="bigquery" + ) + class TestGetSourceTables: def test_query_with_ctes(self): @@ -224,3 +278,18 @@ def test_query_with_multiple_references_of_single_table(self): expected = ["table_2"] assert len(res) == len(expected) assert all([val in expected for val in res]) + + def test_query_with_comment(self): + """...then the comment should be ignored""" + + query = """ + SELECT + 1, + 2 + FROM table_1 /* some comment */ + """ + + res = get_source_tables(query, dialect="bigquery") + + expected = ["table_1"] + assert res == expected diff --git a/tests/sql_mock/test_table_mocks.py b/tests/sql_mock/test_table_mocks.py index 7a29503..10b2abb 100644 --- a/tests/sql_mock/test_table_mocks.py +++ b/tests/sql_mock/test_table_mocks.py @@ -110,7 +110,7 @@ def test_as_sql_input(): ] sql_input = mock_table_instance.as_sql_input() expected = ( - f"{mock_table_instance._sql_mock_meta.table_ref} AS (\n" + f"{mock_table_instance._sql_mock_meta.cte_name} AS (\n" "\tSELECT cast('1' AS Integer) AS col1, cast('value1' AS String) AS col2\n" "\tUNION ALL\n" "\tSELECT cast('2' AS Integer) AS col1, cast('value2' AS String) AS col2\n" @@ -147,7 +147,7 @@ def test_to_sql_model_no_data_provided(self): sql_model = mock_table.as_sql_input() expected_sql_model = ( - "mock_test_table AS (\n" + f"{mock_table._sql_mock_meta.cte_name} AS (\n" "\tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE\n" ")" ) @@ -160,7 +160,7 @@ def test_to_sql_model_single_row_provided(self): sql_model = mock_table.as_sql_input() expected_sql_model = ( - "mock_test_table AS (\n" + f"{mock_table._sql_mock_meta.cte_name} AS (\n" "\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n" ")" ) @@ -173,7 +173,7 @@ def test_to_sql_model_multiple_provided(self): sql_model = mock_table.as_sql_input() expected_sql_model = ( - "mock_test_table AS (\n" + f"{mock_table._sql_mock_meta.cte_name} AS (\n" "\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n" "\tUNION ALL\n" "\tSELECT cast('100' AS Integer) AS col1, cast('another_value' AS String) AS col2\n" @@ -185,6 +185,6 @@ def test_to_sql_model_multiple_provided(self): def test_cte_name(): mock_table_meta = TableMockMeta(table_ref='"my-project.schema.table_name"') - expected_cte_name = "my_project__schema__table_name" + expected_cte_name = "sql_mock__my_project__schema__table_name" assert mock_table_meta.cte_name == expected_cte_name diff --git a/tests/test_table_mocks/test_generate_query.py b/tests/test_table_mocks/test_generate_query.py index 8e09245..8860321 100644 --- a/tests/test_table_mocks/test_generate_query.py +++ b/tests/test_table_mocks/test_generate_query.py @@ -1,4 +1,4 @@ -from textwrap import dedent +import sqlglot from sql_mock.column_mocks import BaseColumnMock from sql_mock.table_mocks import BaseTableMock, table_meta @@ -30,7 +30,7 @@ def test_generate_query_no_cte_provided(mocker): mock_table_instance = MockTestTable.from_dicts([]) mock_table_instance._sql_mock_data.input_data = [mock_table_instance] original_query = f"SELECT * FROM {mock_table_instance._sql_mock_meta.table_ref}" - dummy_return_query = "SELECT foo FROM bar" + dummy_return_query = sqlglot.parse_one("SELECT foo FROM bar") mock_table_instance._sql_mock_data.rendered_query = original_query mocked_select_from_cte = mocker.patch("sql_mock.table_mocks.select_from_cte") @@ -38,7 +38,7 @@ def test_generate_query_no_cte_provided(mocker): "sql_mock.table_mocks.replace_original_table_references", return_value=dummy_return_query ) - expected_query_template_result = dedent( + expected_query_template_result = sqlglot.parse_one( f""" WITH {mock_table_instance._sql_mock_meta.cte_name} AS ( \tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE @@ -61,10 +61,13 @@ def test_generate_query_no_cte_provided(mocker): # Asserts mocked_select_from_cte.assert_not_called() mocked_replace_original_table_references.assert_called_once_with( - expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect + query_ast=expected_query_template_result, + table_ref=mock_table_instance._sql_mock_meta.table_ref, + sql_mock_cte_name=mock_table_instance._sql_mock_meta.cte_name, + dialect=mock_table_instance._sql_dialect ) # The final query should be equal to whatever is returned by `replace_original_table_references` - assert query == mocked_replace_original_table_references.return_value + assert query == mocked_replace_original_table_references.return_value.sql(pretty=True) def test_generate_query_cte_provided(mocker): @@ -75,7 +78,7 @@ def test_generate_query_cte_provided(mocker): original_query = f"SELECT * FROM {mock_table_instance._sql_mock_meta.table_ref}" cte_to_select = "some_cte" cte_adjusted_query = f"SELECT * FROM {cte_to_select}" - dummy_return_query = "SELECT foo FROM bar" + dummy_return_query = sqlglot.parse_one("SELECT foo FROM bar") mock_table_instance._sql_mock_data.rendered_query = original_query mocked_select_from_cte = mocker.patch("sql_mock.table_mocks.select_from_cte", return_value=cte_adjusted_query) @@ -83,7 +86,7 @@ def test_generate_query_cte_provided(mocker): "sql_mock.table_mocks.replace_original_table_references", return_value=dummy_return_query ) - expected_query_template_result = dedent( + expected_query_template_result = sqlglot.parse_one( f""" WITH {mock_table_instance._sql_mock_meta.cte_name} AS ( \tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE @@ -104,13 +107,17 @@ def test_generate_query_cte_provided(mocker): # Asserts mocked_select_from_cte.assert_called_once_with( - original_query, cte_to_select, sql_dialect=MockTestTable._sql_dialect + original_query, cte_to_select, sql_dialect=mock_table_instance._sql_dialect ) + # replace_original_table_references should be called once since we have a single input table mocked_replace_original_table_references.assert_called_once_with( - expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect + query_ast=expected_query_template_result, + table_ref=mock_table_instance._sql_mock_meta.table_ref, + sql_mock_cte_name=mock_table_instance._sql_mock_meta.cte_name, + dialect=mock_table_instance._sql_dialect ) # The final query should be equal to whatever is returned by `replace_original_table_references` - assert query == mocked_replace_original_table_references.return_value + assert query == mocked_replace_original_table_references.return_value.sql(pretty=True) def test_generate_query_sql_has_semicolon():