Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,12 @@ def _override_rtif_airflow_2_x(session: Session = NEW_SESSION) -> None:
assert ti.task is not None
ti.task.template_fields = self.template_fields
rtif = RenderedTaskInstanceFields(ti, render_templates=False)

# delete the old records
session.query(RenderedTaskInstanceFields).filter(
RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined]
RenderedTaskInstanceFields.task_id == self.task_id,
RenderedTaskInstanceFields.run_id == ti.run_id,
RenderedTaskInstanceFields.map_index == ti.map_index,
).delete()
session.add(rtif)
else:
Expand Down
57 changes: 57 additions & 0 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,63 @@ def test_override_rtif_airflow3_with_should_store_compiled_sql_false():
assert operator.overwrite_rtif_after_execution is False


@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test only applies to Airflow 2")
def test_override_rtif_airflow2_filters_by_map_index():
"""Test that _override_rtif correctly filters by map_index for dynamically mapped tasks in Airflow 2.

This test ensures that when deleting old RenderedTaskInstanceFields records,
the map_index filter is applied so that only the current mapped task instance's
records are deleted, not records from other mapped task instances.

This addresses issue #2018 where SQL templated fields weren't properly rendered
for dynamically mapped tasks.
"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

with DAG("test_dag", start_date=datetime(2023, 1, 1)) as dag:
operator = DbtRunLocalOperator(
task_id="test",
profile_config=profile_config,
project_dir="my/dir",
should_store_compiled_sql=True,
dag=dag,
)

mock_ti = MagicMock(spec=TaskInstance)
mock_ti.dag_id = "test_dag"
mock_ti.task_id = "test"
mock_ti.run_id = "test_run_1"
mock_ti.map_index = 2 # Simulating a mapped task instance
mock_ti.task = operator

context = {"ti": mock_ti}

mock_session = MagicMock()
mock_query = MagicMock()
mock_filter = MagicMock()

mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_filter
mock_filter.delete.return_value = None

with patch("airflow.utils.session.provide_session") as mock_provide_session:

def session_decorator(func):
def wrapper(*args, **kwargs):
return func(session=mock_session)

return wrapper

mock_provide_session.side_effect = session_decorator

operator._override_rtif(context)

mock_session.query.assert_called_once_with(RenderedTaskInstanceFields)
mock_query.filter.assert_called_once()
filter_call_arg_map_index = str(mock_query.filter.call_args.args[-1])
assert filter_call_arg_map_index == "rendered_task_instance_fields.map_index = :map_index_1"


def test_dbt_cmd_flags_templating():
"""Test that dbt_cmd_flags supports Jinja templating."""
from datetime import datetime
Expand Down