diff --git a/src/sql_mock/table_mocks.py b/src/sql_mock/table_mocks.py index 8abad7b..b2e1460 100644 --- a/src/sql_mock/table_mocks.py +++ b/src/sql_mock/table_mocks.py @@ -4,7 +4,6 @@ import sqlglot from jinja2 import Template from pydantic import BaseModel, ConfigDict, SkipValidation -from sqlglot.expressions import replace_tables from sql_mock.column_mocks import ColumnMock from sql_mock.constants import NO_INPUT @@ -15,7 +14,7 @@ def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]: return set(key for dictionary in data for key in dictionary.keys()) -def replace_original_table_references(query: str, mock_tables: list["BaseMockTable"], dialect: str = None): +def replace_original_table_references(query: str, mock_tables: list["BaseMockTable"]): """ Replace orignal table references to point them to the mocked data @@ -23,10 +22,10 @@ def replace_original_table_references(query: str, mock_tables: list["BaseMockTab query (str): Original SQL query mock_tables (list[BaseMockTable]): List of BaseMockTable instances that are used as input """ - ast = sqlglot.parse_one(query) - mapping = {mock_table._sql_mock_meta.table_ref: mock_table.cte_name for mock_table in mock_tables} - res = replace_tables(expression=ast, mapping=mapping, dialect=dialect).sql() - return res + for mock_table in mock_tables: + new_reference = mock_table._sql_mock_meta.table_ref.replace(".", "__") + query = query.replace(mock_table._sql_mock_meta.table_ref, new_reference) + return query def select_from_cte(query: str, cte_name: str, sql_dialect: str): @@ -211,12 +210,6 @@ def from_mocks( return instance - @property - def cte_name(self): - mock_meta = self._sql_mock_meta - if getattr(mock_meta, "table_ref", None): - return self._sql_mock_meta.table_ref.replace(".", "__") - def _generate_input_data_cte_snippet(self): # Convert instances into SQL snippets that serve as input to a CTE table_ctes = [mock_table.as_sql_input() for mock_table in self._sql_mock_data.input_data] @@ -256,9 +249,7 @@ def _generate_query( final_columns_to_select=indent(final_columns_to_select, "\t"), ) - query = replace_original_table_references( - query, mock_tables=self._sql_mock_data.input_data, dialect=self._sql_dialect - ) + query = replace_original_table_references(query, mock_tables=self._sql_mock_data.input_data) # Store last query for debugging self._sql_mock_data.last_query = query return query diff --git a/tests/test_table_mocks/test_generate_query.py b/tests/test_table_mocks/test_generate_query.py index 0e31a05..1b299d5 100644 --- a/tests/test_table_mocks/test_generate_query.py +++ b/tests/test_table_mocks/test_generate_query.py @@ -30,8 +30,7 @@ def test_replace_original_table_references_when_reference_exists(): """...then the original table reference should be replaced with the mocked table reference""" query = f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}" mock_tables = [MockTestTable()] - # Note that sqlglot will add a comment with the original table name at the end - expected = "SELECT * FROM data__mock_test_table /* data.mock_test_table */" + expected = "SELECT * FROM data__mock_test_table" assert expected == replace_original_table_references(query, mock_tables) @@ -136,7 +135,7 @@ def test_generate_query_no_cte_provided(mocker): # Asserts mocked_select_from_cte.assert_not_called() mocked_replace_original_table_references.assert_called_once_with( - expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect + expected_query_template_result, mock_tables=[mock_table_instance] ) # The final query should be equal to whatever is returned by `replace_original_table_references` assert query == mocked_replace_original_table_references.return_value @@ -182,7 +181,7 @@ def test_generate_query_cte_provided(mocker): original_query, cte_to_select, sql_dialect=MockTestTable._sql_dialect ) mocked_replace_original_table_references.assert_called_once_with( - expected_query_template_result, mock_tables=[mock_table_instance], dialect=mock_table_instance._sql_dialect + expected_query_template_result, mock_tables=[mock_table_instance] ) # The final query should be equal to whatever is returned by `replace_original_table_references` assert query == mocked_replace_original_table_references.return_value