Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions src/sql_mock/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import TYPE_CHECKING, List

import sqlglot
from sqlglot import exp
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.expressions import replace_tables
from sqlglot.optimizer.scope import build_scope

Expand All @@ -16,20 +16,30 @@
def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
    """
    Collect every distinct key that appears in any dictionary of the list

    Args:
        data (list[dict]): Dictionaries whose keys should be gathered

    Returns:
        set[str]: The union of all keys; an empty set when `data` is empty
    """
    unique_keys = set()
    for row in data:
        unique_keys.update(row.keys())
    return unique_keys

def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression:
    """
    Remove a CTE from a query

    The AST is modified in place. If the removed CTE was the last one of its WITH clause,
    the WITH clause is removed as well so that no empty `WITH` is left behind.

    Args:
        query_ast (sqlglot.Expression): The AST of the query
        cte_name (str): The name of the CTE to remove

    Returns:
        sqlglot.Expression: The same AST object, without the CTEs whose alias equals `cte_name`
    """
    # Collect the matches before mutating so that popping nodes cannot disturb the traversal
    matching_ctes = [cte for cte in query_ast.find_all(sqlglot.exp.CTE) if cte.alias == cte_name]

    for cte in matching_ctes:
        with_clause = cte.parent
        cte.pop()
        # Mirror sqlglot's own `eliminate_ctes`: a WITH clause without any CTE must go as well
        if with_clause is not None and not with_clause.expressions:
            with_clause.pop()

    return query_ast

def replace_original_table_references(query: str, mock_tables: list["BaseTableMock"], dialect: str = None):
def replace_original_table_references(
    query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str
) -> sqlglot.Expression:
    """
    Replace original table reference with sql mock cte name to point them to the mocked data

    Args:
        query_ast (sqlglot.Expression): Original SQL query - parsed by sqlglot
        table_ref (str): Table ref to be replaced
        sql_mock_cte_name (str): Name of the CTE that will contain the mocked data
        dialect (str): The SQL dialect that is passed on to sqlglot's `replace_tables`

    Returns:
        sqlglot.Expression: The expression returned by `replace_tables` for the
            single-entry mapping `{table_ref: sql_mock_cte_name}`
    """
    return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect)


def select_from_cte(query: str, cte_name: str, sql_dialect: str):
Expand Down Expand Up @@ -66,7 +76,8 @@ def parse_table_refs(table_ref, dialect):
return table_ref if not table_ref else str(sqlglot.parse_one(table_ref, dialect=dialect))


def _strip_alias_transformer(node):
def _clean_table_ref_transformer(node):
    """Strip the alias and any comments from `node` and hand the same node back."""
    node.set("alias", None)
    node.comments = []
    return node

Expand All @@ -81,7 +92,7 @@ def get_source_tables(query, dialect) -> List[str]:
root = build_scope(ast)

tables = {
str(source.transform(_strip_alias_transformer))
str(source.transform(_clean_table_ref_transformer))
# Traverse the Scope tree, not the AST
for scope in root.traverse()
# `selected_sources` contains sources that have been selected in this scope, e.g. in a FROM or JOIN clause.
Expand All @@ -92,7 +103,7 @@ def get_source_tables(query, dialect) -> List[str]:
# if the selected source is a table,
# then `source` will be a Table instance.
for alias, (node, source) in scope.selected_sources.items()
if isinstance(source, exp.Table)
if isinstance(source, sqlglot.exp.Table)
}

return list(tables)
Expand Down Expand Up @@ -125,7 +136,27 @@ def validate_input_mocks(input_mocks: List["BaseTableMock"]):


def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["BaseTableMock"], dialect: str) -> None:
missing_source_table_mocks = get_source_tables(query=query, dialect=dialect)
"""
Validate that all input mocks are provided for a query.
Mocks can replace CTEs or tables in the query. If a CTE is replaced, upstream table references don't need to be provided anymore.

Args:
query (str): The query to validate
input_mocks (List[BaseTableMock]): The input mocks that are provided
dialect (str): The SQL dialect to use for parsing the query
"""
provided_table_refs = [mock_table._sql_mock_meta.table_ref for mock_table in input_mocks if hasattr(mock_table._sql_mock_meta, "table_ref")]
ast = sqlglot.parse_one(query, dialect=dialect)

# In case the table_ref is a CTE, we need to remove it from the query
for table_ref in provided_table_refs:
ast = remove_cte_from_query(query_ast=ast, cte_name=table_ref)

# Now we might have some superfluous CTEs that are not referenced anymore
remaining_query = eliminate_ctes(ast).sql(dialect=dialect, pretty=True)

# The remaining query should not contain raw table references anymore if everything is mocked correctly
missing_source_table_mocks = get_source_tables(query=remaining_query, dialect=dialect)
for mock_table in input_mocks:
table_ref = getattr(mock_table._sql_mock_meta, "table_ref", None)
# If the table exists as mock, we can remove it from missing source tables
Expand Down
28 changes: 23 additions & 5 deletions src/sql_mock/table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Type

import sqlglot
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from jinja2 import Template
from pydantic import BaseModel, ConfigDict, SkipValidation

Expand All @@ -11,6 +12,7 @@
get_keys_from_list_of_dicts,
parse_table_refs,
replace_original_table_references,
remove_cte_from_query,
select_from_cte,
validate_all_input_mocks_for_query_provided,
validate_input_mocks,
Expand Down Expand Up @@ -72,7 +74,7 @@ class BaseTableMock:
col1 = Int(default=1)

Attributes:
_sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatcially created on instantiation.
_sql_mock_data (SQLMockData): A class that stores data which is for processing. This is automatically created on instantiation.
_sql_dialect (str): The sql dialect that the mock model uses. It will be leveraged by sqlglot.
"""

Expand Down Expand Up @@ -184,9 +186,13 @@ def _generate_query(
final_columns_to_select=final_columns_to_select,
)

query = replace_original_table_references(
query, mock_tables=self._sql_mock_data.input_data, dialect=self._sql_dialect
)
query_ast = sqlglot.parse_one(query, dialect=self._sql_dialect)
for mock_table in self._sql_mock_data.input_data:
query_ast = mock_table.replace_original_references(query_ast=query_ast)

# Remove superfluous CTEs
query_ast = eliminate_ctes(query_ast)
query = query_ast.sql(pretty=True, dialect=self._sql_dialect)
# Store last query for debugging
self._sql_mock_data.last_query = query
return query
Expand Down Expand Up @@ -235,6 +241,17 @@ def as_sql_input(self):
snippet = indent(f"SELECT {snippet}", "\t")
return f"{self._sql_mock_meta.cte_name} AS (\n{snippet}\n)"

def replace_original_references(self, query_ast: sqlglot.Expression) -> sqlglot.Expression:
    """
    Point all references to this mock's `table_ref` in the query to the mock's CTE

    Args:
        query_ast (sqlglot.Expression): The parsed query in which references get replaced

    Returns:
        sqlglot.Expression: The expression returned by `replace_original_table_references`
    """
    meta = self._sql_mock_meta

    # When the mocked object is a CTE, its original definition has to be dropped from the query first
    ast_without_original_cte = remove_cte_from_query(query_ast=query_ast, cte_name=meta.table_ref)

    return replace_original_table_references(
        query_ast=ast_without_original_cte,
        table_ref=meta.table_ref,
        sql_mock_cte_name=meta.cte_name,
        dialect=self._sql_dialect,
    )

def _assert_equal(
self,
data: [dict],
Expand Down Expand Up @@ -343,4 +360,5 @@ class TableMockMeta(BaseModel):
@property
def cte_name(self):
    """
    Name of the CTE that holds the mocked data for `table_ref`

    Double quotes are dropped, dots become double underscores and dashes become single
    underscores. The result is prefixed with `sql_mock__`.

    Returns:
        The CTE name, or None when no `table_ref` is set
    """
    table_ref = getattr(self, "table_ref", None)
    if not table_ref:
        return None

    cleaned_ref = table_ref.replace('"', "").replace(".", "__").replace("-", "_")
    return f"sql_mock__{cleaned_ref}"
83 changes: 76 additions & 7 deletions tests/sql_mock/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,26 @@ class MockTestTable(BaseTableMock):
class TestReplaceOriginalTableReference:
    """Tests for `replace_original_table_references` operating on a parsed query."""

    def test_replace_original_table_references_when_reference_exists(self):
        """...then the original table reference should be replaced with the mocked table reference"""
        query_ast = sqlglot.parse_one(f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}")
        # Note that sqlglot will add a comment with the original table name at the end
        expected = f"SELECT\n *\nFROM {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */"
        assert expected == replace_original_table_references(
            query_ast=query_ast,
            table_ref=MockTestTable._sql_mock_meta.table_ref,
            sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
            dialect="bigquery",
        ).sql(pretty=True)

    def test_replace_original_table_references_when_reference_does_not_exist(self):
        """...then the original reference should not be replaced"""
        query_ast = sqlglot.parse_one("SELECT * FROM some_table")
        expected = "SELECT\n *\nFROM some_table"
        assert expected == replace_original_table_references(
            query_ast=query_ast,
            table_ref=MockTestTable._sql_mock_meta.table_ref,
            sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
            dialect="bigquery",
        ).sql(pretty=True)


class TestSelectFromCTE:
Expand Down Expand Up @@ -190,6 +198,52 @@ def test_input_mocks_not_in_query(self):
== f"Your input mock {TableRefMock.__class__.__name__} is not a table that is referenced in the query"
)

def test_input_mocks_missing_for_tables_within_mocked_cte(self):
    """...then validation passes, because the table is only read inside the CTE that is mocked"""
    query = """
WITH cte_1 AS (
SELECT * FROM some_table
),

cte_2 AS (
SELECT a, b
FROM cte_1
WHERE a = 'foo'
)

SELECT a, b, * FROM cte_2
"""

    @table_meta(table_ref="cte_1")
    class MockedCte1(BaseTableMock):
        pass

    provided_mocks = [MockedCte1()]

    # The test passes as long as this call does not raise
    validate_all_input_mocks_for_query_provided(query=query, input_mocks=provided_mocks, dialect="bigquery")

def test_cte_superfluous_after_mocking(self):
    """...then validation passes, because the unreferenced CTE gets removed and needs no mock"""
    query = """
WITH cte_1 AS (
SELECT * FROM some_table
),

cte_2 AS (
SELECT a, b
FROM cte_1
WHERE a = 'foo'
)

SELECT a, b, * FROM cte_2
"""

    # This will make cte_1 superfluous
    @table_meta(table_ref="cte_2")
    class Cte2Mock(BaseTableMock):
        pass

    provided_mocks = [Cte2Mock()]

    # The test passes as long as this call does not raise
    validate_all_input_mocks_for_query_provided(query=query, input_mocks=provided_mocks, dialect="bigquery")


class TestGetSourceTables:
def test_query_with_ctes(self):
Expand Down Expand Up @@ -224,3 +278,18 @@ def test_query_with_multiple_references_of_single_table(self):
expected = ["table_2"]
assert len(res) == len(expected)
assert all([val in expected for val in res])

def test_query_with_comment(self):
    """...then the comment must not end up in the returned table name"""
    query_with_comment = """
SELECT
1,
2
FROM table_1 /* some comment */
"""

    source_tables = get_source_tables(query_with_comment, dialect="bigquery")

    assert source_tables == ["table_1"]
10 changes: 5 additions & 5 deletions tests/sql_mock/test_table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_as_sql_input():
]
sql_input = mock_table_instance.as_sql_input()
expected = (
f"{mock_table_instance._sql_mock_meta.table_ref} AS (\n"
f"{mock_table_instance._sql_mock_meta.cte_name} AS (\n"
"\tSELECT cast('1' AS Integer) AS col1, cast('value1' AS String) AS col2\n"
"\tUNION ALL\n"
"\tSELECT cast('2' AS Integer) AS col1, cast('value2' AS String) AS col2\n"
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_to_sql_model_no_data_provided(self):
sql_model = mock_table.as_sql_input()

expected_sql_model = (
"mock_test_table AS (\n"
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
"\tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE\n"
")"
)
Expand All @@ -160,7 +160,7 @@ def test_to_sql_model_single_row_provided(self):
sql_model = mock_table.as_sql_input()

expected_sql_model = (
"mock_test_table AS (\n"
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
"\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n"
")"
)
Expand All @@ -173,7 +173,7 @@ def test_to_sql_model_multiple_provided(self):
sql_model = mock_table.as_sql_input()

expected_sql_model = (
"mock_test_table AS (\n"
f"{mock_table._sql_mock_meta.cte_name} AS (\n"
"\tSELECT cast('42' AS Integer) AS col1, cast('test_value' AS String) AS col2\n"
"\tUNION ALL\n"
"\tSELECT cast('100' AS Integer) AS col1, cast('another_value' AS String) AS col2\n"
Expand All @@ -185,6 +185,6 @@ def test_to_sql_model_multiple_provided(self):
def test_cte_name():
    """A quoted table ref with dashes and dots yields a cleaned CTE name with the `sql_mock__` prefix"""
    meta = TableMockMeta(table_ref='"my-project.schema.table_name"')

    assert meta.cte_name == "sql_mock__my_project__schema__table_name"
Loading