Skip to content

Commit 4fbd039

Browse files
authored
Allow TableMock to also mock CTEs (#39)
1 parent 52c9dc3 commit 4fbd039

File tree

5 files changed

+165
-40
lines changed

5 files changed

+165
-40
lines changed

src/sql_mock/helpers.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import TYPE_CHECKING, List
33

44
import sqlglot
5-
from sqlglot import exp
5+
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
66
from sqlglot.expressions import replace_tables
77
from sqlglot.optimizer.scope import build_scope
88

@@ -16,20 +16,30 @@
1616
def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
1717
return set(key for dictionary in data for key in dictionary.keys())
1818

19+
def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression:
20+
"""
21+
Remove a CTE from a query
22+
23+
Args:
24+
query_ast (sqlglot.Expression): The AST of the query
25+
cte_name (str): The name of the CTE to remove
26+
"""
27+
for cte in query_ast.find_all(sqlglot.exp.CTE):
28+
if cte.alias == cte_name:
29+
cte.pop()
30+
return query_ast
1931

20-
def replace_original_table_references(query: str, mock_tables: list["BaseTableMock"], dialect: str = None):
32+
def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression:
2133
"""
22-
Replace orignal table references to point them to the mocked data
34+
Replace original table reference with sql mock cte name to point them to the mocked data
2335
2436
Args:
25-
query (str): Original SQL query
26-
mock_tables (list[BaseTableMock]): List of BaseTableMock instances that are used as input
37+
query_ast (sqlglot.Expression): Original SQL query, already parsed by sqlglot
38+
table_ref (str): Table ref to be replaced
39+
sql_mock_cte_name (str): Name of the CTE that will contain the mocked data
2740
dialect (str): The SQL dialect to use for parsing the query
2841
"""
29-
ast = sqlglot.parse_one(query, dialect=dialect)
30-
mapping = {mock_table._sql_mock_meta.table_ref: mock_table._sql_mock_meta.cte_name for mock_table in mock_tables}
31-
res = replace_tables(expression=ast, mapping=mapping, dialect=dialect).sql(pretty=True, dialect=dialect)
32-
return res
42+
return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect)
3343

3444

3545
def select_from_cte(query: str, cte_name: str, sql_dialect: str):
@@ -66,7 +76,8 @@ def parse_table_refs(table_ref, dialect):
6676
return table_ref if not table_ref else str(sqlglot.parse_one(table_ref, dialect=dialect))
6777

6878

69-
def _strip_alias_transformer(node):
79+
def _clean_table_ref_transformer(node):
80+
node.comments = []
7081
node.set("alias", None)
7182
return node
7283

@@ -81,7 +92,7 @@ def get_source_tables(query, dialect) -> List[str]:
8192
root = build_scope(ast)
8293

8394
tables = {
84-
str(source.transform(_strip_alias_transformer))
95+
str(source.transform(_clean_table_ref_transformer))
8596
# Traverse the Scope tree, not the AST
8697
for scope in root.traverse()
8798
# `selected_sources` contains sources that have been selected in this scope, e.g. in a FROM or JOIN clause.
@@ -92,7 +103,7 @@ def get_source_tables(query, dialect) -> List[str]:
92103
# if the selected source is a table,
93104
# then `source` will be a Table instance.
94105
for alias, (node, source) in scope.selected_sources.items()
95-
if isinstance(source, exp.Table)
106+
if isinstance(source, sqlglot.exp.Table)
96107
}
97108

98109
return list(tables)
@@ -125,7 +136,27 @@ def validate_input_mocks(input_mocks: List["BaseTableMock"]):
125136

126137

127138
def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["BaseTableMock"], dialect: str) -> None:
128-
missing_source_table_mocks = get_source_tables(query=query, dialect=dialect)
139+
"""
140+
Validate that all input mocks are provided for a query.
141+
Mocks can replace CTEs or tables in the query. If a CTE is replaced, upstream table references don't need to be provided anymore.
142+
143+
Args:
144+
query (str): The query to validate
145+
input_mocks (List[BaseTableMock]): The input mocks that are provided
146+
dialect (str): The SQL dialect to use for parsing the query
147+
"""
148+
provided_table_refs = [mock_table._sql_mock_meta.table_ref for mock_table in input_mocks if hasattr(mock_table._sql_mock_meta, "table_ref")]
149+
ast = sqlglot.parse_one(query, dialect=dialect)
150+
151+
# In case the table_ref is a CTE, we need to remove it from the query
152+
for table_ref in provided_table_refs:
153+
ast = remove_cte_from_query(query_ast=ast, cte_name=table_ref)
154+
155+
# Now we might have some superfluous CTEs that are not referenced anymore
156+
remaining_query = eliminate_ctes(ast).sql(dialect=dialect, pretty=True)
157+
158+
# The remaining query should not contain raw table references anymore if everything is mocked correctly
159+
missing_source_table_mocks = get_source_tables(query=remaining_query, dialect=dialect)
129160
for mock_table in input_mocks:
130161
table_ref = getattr(mock_table._sql_mock_meta, "table_ref", None)
131162
# If the table exists as mock, we can remove it from missing source tables

src/sql_mock/table_mocks.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Type
33

44
import sqlglot
5+
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
56
from jinja2 import Template
67
from pydantic import BaseModel, ConfigDict, SkipValidation
78

@@ -11,6 +12,7 @@
1112
get_keys_from_list_of_dicts,
1213
parse_table_refs,
1314
replace_original_table_references,
15+
remove_cte_from_query,
1416
select_from_cte,
1517
validate_all_input_mocks_for_query_provided,
1618
validate_input_mocks,
@@ -72,7 +74,7 @@ class BaseTableMock:
7274
col1 = Int(default=1)
7375
7476
Attributes:
75-
_sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatcially created on instantiation.
77+
_sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatically created on instantiation.
7678
_sql_dialect (str): The sql dialect that the mock model uses. It will be leveraged by sqlglot.
7779
"""
7880

@@ -184,9 +186,13 @@ def _generate_query(
184186
final_columns_to_select=final_columns_to_select,
185187
)
186188

187-
query = replace_original_table_references(
188-
query, mock_tables=self._sql_mock_data.input_data, dialect=self._sql_dialect
189-
)
189+
query_ast = sqlglot.parse_one(query, dialect=self._sql_dialect)
190+
for mock_table in self._sql_mock_data.input_data:
191+
query_ast = mock_table.replace_original_references(query_ast=query_ast)
192+
193+
# Remove superfluous CTEs
194+
query_ast = eliminate_ctes(query_ast)
195+
query = query_ast.sql(pretty=True, dialect=self._sql_dialect)
190196
# Store last query for debugging
191197
self._sql_mock_data.last_query = query
192198
return query
@@ -235,6 +241,17 @@ def as_sql_input(self):
235241
snippet = indent(f"SELECT {snippet}", "\t")
236242
return f"{self._sql_mock_meta.cte_name} AS (\n{snippet}\n)"
237243

244+
def replace_original_references(self, query_ast: sqlglot.Expression) -> sqlglot.Expression:
245+
# In case we mock a CTE, we need to drop the original CTE from the query
246+
query_ast = remove_cte_from_query(query_ast=query_ast, cte_name=self._sql_mock_meta.table_ref)
247+
248+
return replace_original_table_references(
249+
query_ast=query_ast,
250+
table_ref=self._sql_mock_meta.table_ref,
251+
sql_mock_cte_name=self._sql_mock_meta.cte_name,
252+
dialect=self._sql_dialect,
253+
)
254+
238255
def _assert_equal(
239256
self,
240257
data: [dict],
@@ -343,4 +360,5 @@ class TableMockMeta(BaseModel):
343360
@property
344361
def cte_name(self):
345362
if getattr(self, "table_ref", None):
346-
return self.table_ref.replace('"', "").replace(".", "__").replace("-", "_")
363+
cleaned_ref = self.table_ref.replace('"', "").replace(".", "__").replace("-", "_")
364+
return f"sql_mock__{cleaned_ref}"

tests/sql_mock/test_helpers.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,26 @@ class MockTestTable(BaseTableMock):
4848
class TestReplaceOriginalTableReference:
4949
def test_replace_original_table_references_when_reference_exists(self):
5050
"""...then the original table reference should be replaced with the mocked table reference"""
51-
query = f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}"
52-
mock_tables = [MockTestTable()]
51+
query_ast = sqlglot.parse_one(f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}")
5352
# Note that sqlglot will add a comment with the original table name at the end
54-
expected = "SELECT\n *\nFROM data__mock_test_table /* data.mock_test_table */"
55-
assert expected == replace_original_table_references(query, mock_tables)
53+
expected = f"SELECT\n *\nFROM {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */"
54+
assert expected == replace_original_table_references(
55+
query_ast=query_ast,
56+
table_ref=MockTestTable._sql_mock_meta.table_ref,
57+
sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
58+
dialect="bigquery",
59+
).sql(pretty=True)
5660

5761
def test_replace_original_table_references_when_reference_does_not_exist(self):
5862
"""...then the original reference should not be replaced"""
59-
query = "SELECT * FROM some_table"
60-
mock_tables = [MockTestTable()]
63+
query_ast = sqlglot.parse_one("SELECT * FROM some_table")
6164
expected = "SELECT\n *\nFROM some_table"
62-
assert expected == replace_original_table_references(query, mock_tables)
65+
assert expected == replace_original_table_references(
66+
query_ast=query_ast,
67+
table_ref=MockTestTable._sql_mock_meta.table_ref,
68+
sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
69+
dialect="bigquery",
70+
).sql(pretty=True)
6371

6472

6573
class TestSelectFromCTE:
@@ -190,6 +198,52 @@ def test_input_mocks_not_in_query(self):
190198
== f"Your input mock {TableRefMock.__class__.__name__} is not a table that is referenced in the query"
191199
)
192200

201+
def test_input_mocks_missing_for_tables_within_mocked_cte(self):
202+
"""...then the validation should pass since the CTE would be mocked anyway"""
203+
query = """
204+
WITH cte_1 AS (
205+
SELECT * FROM some_table
206+
),
207+
208+
cte_2 AS (
209+
SELECT a, b
210+
FROM cte_1
211+
WHERE a = 'foo'
212+
)
213+
214+
SELECT a, b, * FROM cte_2
215+
"""
216+
@table_meta(table_ref="cte_1")
217+
class Cte1Mock(BaseTableMock):
218+
pass
219+
220+
validate_all_input_mocks_for_query_provided(
221+
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
222+
)
223+
224+
def test_cte_superfluous_after_mocking(self):
225+
"""...then the validation should pass since the CTE will be removed anyway and does not need to be mocked"""
226+
query = """
227+
WITH cte_1 AS (
228+
SELECT * FROM some_table
229+
),
230+
231+
cte_2 AS (
232+
SELECT a, b
233+
FROM cte_1
234+
WHERE a = 'foo'
235+
)
236+
237+
SELECT a, b, * FROM cte_2
238+
"""
239+
@table_meta(table_ref="cte_2") # This will make cte_1 superfluous
240+
class Cte1Mock(BaseTableMock):
241+
pass
242+
243+
validate_all_input_mocks_for_query_provided(
244+
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
245+
)
246+
193247

194248
class TestGetSourceTables:
195249
def test_query_with_ctes(self):
@@ -224,3 +278,18 @@ def test_query_with_multiple_references_of_single_table(self):
224278
expected = ["table_2"]
225279
assert len(res) == len(expected)
226280
assert all([val in expected for val in res])
281+
282+
def test_query_with_comment(self):
283+
"""...then the comment should be ignored"""
284+
285+
query = """
286+
SELECT
287+
1,
288+
2
289+
FROM table_1 /* some comment */
290+
"""
291+
292+
res = get_source_tables(query, dialect="bigquery")
293+
294+
expected = ["table_1"]
295+
assert res == expected

tests/sql_mock/test_table_mocks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_as_sql_input():
110110
]
111111
sql_input = mock_table_instance.as_sql_input()
112112
expected = (
113-
f"{mock_table_instance._sql_mock_meta.table_ref} AS (\n"
113+
f"{mock_table_instance._sql_mock_meta.cte_name} AS (\n"
114114
"\tSELECT cast('1' AS Integer) AS col1, cast('value1' AS String) AS col2\n"
115115
"\tUNION ALL\n"
116116
"\tSELECT cast('2' AS Integer) AS col1, cast('value2' AS String) AS col2\n"
@@ -147,7 +147,7 @@ def test_to_sql_model_no_data_provided(self):
147147
sql_model = mock_table.as_sql_input()
148148

149149
expected_sql_model = (
150-
"mock_test_table AS (\n"
150+
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
151151
"\tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE\n"
152152
")"
153153
)
@@ -160,7 +160,7 @@ def test_to_sql_model_single_row_provided(self):
160160
sql_model = mock_table.as_sql_input()
161161

162162
expected_sql_model = (
163-
"mock_test_table AS (\n"
163+
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
164164
"\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n"
165165
")"
166166
)
@@ -173,7 +173,7 @@ def test_to_sql_model_multiple_provided(self):
173173
sql_model = mock_table.as_sql_input()
174174

175175
expected_sql_model = (
176-
"mock_test_table AS (\n"
176+
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
177177
"\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n"
178178
"\tUNION ALL\n"
179179
"\tSELECT cast('100' AS Integer) AS col1, cast('another_value' AS String) AS col2\n"
@@ -185,6 +185,6 @@ def test_to_sql_model_multiple_provided(self):
185185
def test_cte_name():
186186
mock_table_meta = TableMockMeta(table_ref='"my-project.schema.table_name"')
187187

188-
expected_cte_name = "my_project__schema__table_name"
188+
expected_cte_name = "sql_mock__my_project__schema__table_name"
189189

190190
assert mock_table_meta.cte_name == expected_cte_name

0 commit comments

Comments
 (0)