22from typing import TYPE_CHECKING , List
33
44import sqlglot
5- from sqlglot import exp
5+ from sqlglot . optimizer . eliminate_ctes import eliminate_ctes
66from sqlglot .expressions import replace_tables
77from sqlglot .optimizer .scope import build_scope
88
def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
    """
    Collect every key that appears in at least one of the given dictionaries

    Args:
        data (list[dict]): The dictionaries whose keys should be collected

    Returns:
        set[str]: The union of the keys of all dictionaries (empty when `data` is empty)
    """
    # Iterating a dict yields its keys, so no explicit `.keys()` call is needed
    return {key for dictionary in data for key in dictionary}
1818
def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression:
    """
    Remove a CTE from a query

    The passed AST is modified in place and returned for convenience.

    Args:
        query_ast (sqlglot.Expression): The AST of the query
        cte_name (str): The name of the CTE to remove

    Returns:
        sqlglot.Expression: The AST without the CTEs whose alias equals `cte_name`
    """
    # Collect the matches before mutating anything: popping nodes while `find_all`
    # is still walking the tree changes the structure underneath the traversal.
    matching_ctes = [cte for cte in query_ast.find_all(sqlglot.exp.CTE) if cte.alias == cte_name]

    for cte in matching_ctes:
        # Read the parent first, it is the node the CTE gets removed from
        with_clause = cte.parent
        cte.pop()
        # If that was the last CTE, drop the WITH clause as well so that no empty
        # WITH node is left behind in the AST
        if isinstance(with_clause, sqlglot.exp.With) and not with_clause.expressions:
            with_clause.pop()

    return query_ast
1931
def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression:
    """
    Replace original table reference with sql mock cte name to point them to the mocked data

    Args:
        query_ast (sqlglot.Expression): Original SQL query - parsed by sqlglot
        table_ref (str): Table ref to be replaced
        sql_mock_cte_name (str): Name of the CTE that will contain the mocked data
        dialect (str): The SQL dialect to use for parsing the query

    Returns:
        sqlglot.Expression: The expression returned by `replace_tables` for the mapping
            `table_ref` -> `sql_mock_cte_name`
    """
    mapping = {table_ref: sql_mock_cte_name}
    return replace_tables(expression=query_ast, mapping=mapping, dialect=dialect)
3343
3444
3545def select_from_cte (query : str , cte_name : str , sql_dialect : str ):
@@ -66,7 +76,8 @@ def parse_table_refs(table_ref, dialect):
6676 return table_ref if not table_ref else str (sqlglot .parse_one (table_ref , dialect = dialect ))
6777
6878
def _clean_table_ref_transformer(node):
    """
    Clear the alias and the comments of a node and hand the node back

    Args:
        node: The node to clean; it must offer `set` and a `comments` attribute
    """
    node.set("alias", None)
    node.comments = []
    return node
7283
@@ -81,7 +92,7 @@ def get_source_tables(query, dialect) -> List[str]:
8192 root = build_scope (ast )
8293
8394 tables = {
84- str (source .transform (_strip_alias_transformer ))
95+ str (source .transform (_clean_table_ref_transformer ))
8596 # Traverse the Scope tree, not the AST
8697 for scope in root .traverse ()
8798 # `selected_sources` contains sources that have been selected in this scope, e.g. in a FROM or JOIN clause.
@@ -92,7 +103,7 @@ def get_source_tables(query, dialect) -> List[str]:
92103 # if the selected source is a table,
93104 # then `source` will be a Table instance.
94105 for alias , (node , source ) in scope .selected_sources .items ()
95- if isinstance (source , exp .Table )
106+ if isinstance (source , sqlglot . exp .Table )
96107 }
97108
98109 return list (tables )
@@ -125,7 +136,27 @@ def validate_input_mocks(input_mocks: List["BaseTableMock"]):
125136
126137
127138def validate_all_input_mocks_for_query_provided (query : str , input_mocks : List ["BaseTableMock" ], dialect : str ) -> None :
128- missing_source_table_mocks = get_source_tables (query = query , dialect = dialect )
139+ """
140+ Validate that all input mocks are provided for a query.
141+ Mocks can replace CTEs or tables in the query. If a CTE is replaced, upstream table references don't need to be provided anymore.
142+
143+ Args:
144+ query (str): The query to validate
145+ input_mocks (List[BaseTableMock]): The input mocks that are provided
146+ dialect (str): The SQL dialect to use for parsing the query
147+ """
148+ provided_table_refs = [mock_table ._sql_mock_meta .table_ref for mock_table in input_mocks if hasattr (mock_table ._sql_mock_meta , "table_ref" )]
149+ ast = sqlglot .parse_one (query , dialect = dialect )
150+
151+ # In case the table_ref is a CTE, we need to remove it from the query
152+ for table_ref in provided_table_refs :
153+ ast = remove_cte_from_query (query_ast = ast , cte_name = table_ref )
154+
155+ # Now we might have some superfluous CTEs that are not referenced anymore
156+ remaining_query = eliminate_ctes (ast ).sql (dialect = dialect , pretty = True )
157+
158+ # The remaining query should not contain raw table references anymore if everything is mocked correctly
159+ missing_source_table_mocks = get_source_tables (query = remaining_query , dialect = dialect )
129160 for mock_table in input_mocks :
130161 table_ref = getattr (mock_table ._sql_mock_meta , "table_ref" , None )
131162 # If the table exists as mock, we can remove it from missing source tables
0 commit comments