Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions src/sql_mock/table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sqlglot
from jinja2 import Template
from pydantic import BaseModel, ConfigDict, SkipValidation
from sqlglot.expressions import replace_tables

from sql_mock.column_mocks import ColumnMock
from sql_mock.constants import NO_INPUT
Expand All @@ -15,18 +14,18 @@ def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
return set(key for dictionary in data for key in dictionary.keys())


def replace_original_table_references(query: str, mock_tables: list["BaseMockTable"], dialect: str = None):
def replace_original_table_references(query: str, mock_tables: list["BaseMockTable"]):
"""
Replace orignal table references to point them to the mocked data

Args:
query (str): Original SQL query
mock_tables (list[BaseMockTable]): List of BaseMockTable instances that are used as input
"""
ast = sqlglot.parse_one(query)
mapping = {mock_table._sql_mock_meta.table_ref: mock_table.cte_name for mock_table in mock_tables}
res = replace_tables(expression=ast, mapping=mapping, dialect=dialect).sql()
return res
for mock_table in mock_tables:
new_reference = mock_table._sql_mock_meta.table_ref.replace(".", "__")
query = query.replace(mock_table._sql_mock_meta.table_ref, new_reference)
return query


def select_from_cte(query: str, cte_name: str, sql_dialect: str):
Expand Down Expand Up @@ -211,12 +210,6 @@ def from_mocks(

return instance

@property
def cte_name(self):
mock_meta = self._sql_mock_meta
if getattr(mock_meta, "table_ref", None):
return self._sql_mock_meta.table_ref.replace(".", "__")

def _generate_input_data_cte_snippet(self):
# Convert instances into SQL snippets that serve as input to a CTE
table_ctes = [mock_table.as_sql_input() for mock_table in self._sql_mock_data.input_data]
Expand Down Expand Up @@ -256,9 +249,7 @@ def _generate_query(
final_columns_to_select=indent(final_columns_to_select, "\t"),
)

query = replace_original_table_references(
query, mock_tables=self._sql_mock_data.input_data, dialect=self._sql_dialect
)
query = replace_original_table_references(query, mock_tables=self._sql_mock_data.input_data)
# Store last query for debugging
self._sql_mock_data.last_query = query
return query
Expand Down
7 changes: 3 additions & 4 deletions tests/test_table_mocks/test_generate_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def test_replace_original_table_references_when_reference_exists():
"""...then the original table reference should be replaced with the mocked table reference"""
query = f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}"
mock_tables = [MockTestTable()]
# Note that sqlglot will add a comment with the original table name at the end
expected = "SELECT * FROM data__mock_test_table /* data.mock_test_table */"
expected = "SELECT * FROM data__mock_test_table"
assert expected == replace_original_table_references(query, mock_tables)


Expand Down Expand Up @@ -136,7 +135,7 @@ def test_generate_query_no_cte_provided(mocker):
# Asserts
mocked_select_from_cte.assert_not_called()
mocked_replace_original_table_references.assert_called_once_with(
expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect
expected_query_template_result, mock_tables=[mock_table_instance]
)
# The final query should be equal to whatever is returned by `replace_original_table_references`
assert query == mocked_replace_original_table_references.return_value
Expand Down Expand Up @@ -182,7 +181,7 @@ def test_generate_query_cte_provided(mocker):
original_query, cte_to_select, sql_dialect=MockTestTable._sql_dialect
)
mocked_replace_original_table_references.assert_called_once_with(
expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect
expected_query_template_result, mock_tables=[mock_table_instance]
)
# The final query should be equal to whatever is returned by `replace_original_table_references`
assert query == mocked_replace_original_table_references.return_value