diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 2d0aa97744..7a69402aa3 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -444,12 +444,12 @@ def _override_rtif_airflow_2_x(session: Session = NEW_SESSION) -> None: assert ti.task is not None ti.task.template_fields = self.template_fields rtif = RenderedTaskInstanceFields(ti, render_templates=False) - # delete the old records session.query(RenderedTaskInstanceFields).filter( RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined] RenderedTaskInstanceFields.task_id == self.task_id, RenderedTaskInstanceFields.run_id == ti.run_id, + RenderedTaskInstanceFields.map_index == ti.map_index, ).delete() session.add(rtif) else: diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 6b536819ae..6783b67c6a 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -2014,6 +2014,63 @@ def test_override_rtif_airflow3_with_should_store_compiled_sql_false(): assert operator.overwrite_rtif_after_execution is False +@pytest.mark.skipif(version.parse(airflow_version).major == 3, reason="Test only applies to Airflow 2") +def test_override_rtif_airflow2_filters_by_map_index(): + """Test that _override_rtif correctly filters by map_index for dynamically mapped tasks in Airflow 2. + + This test ensures that when deleting old RenderedTaskInstanceFields records, + the map_index filter is applied so that only the current mapped task instance's + records are deleted, not records from other mapped task instances. + + This addresses issue #2018 where SQL templated fields weren't properly rendered + for dynamically mapped tasks. + """ + from airflow.models.renderedtifields import RenderedTaskInstanceFields + + with DAG("test_dag", start_date=datetime(2023, 1, 1)) as dag: + operator = DbtRunLocalOperator( + task_id="test", + profile_config=profile_config, + project_dir="my/dir", + should_store_compiled_sql=True, + dag=dag, + ) + + mock_ti = MagicMock(spec=TaskInstance) + mock_ti.dag_id = "test_dag" + mock_ti.task_id = "test" + mock_ti.run_id = "test_run_1" + mock_ti.map_index = 2 # Simulating a mapped task instance + mock_ti.task = operator + + context = {"ti": mock_ti} + + mock_session = MagicMock() + mock_query = MagicMock() + mock_filter = MagicMock() + + mock_session.query.return_value = mock_query + mock_query.filter.return_value = mock_filter + mock_filter.delete.return_value = None + + with patch("airflow.utils.session.provide_session") as mock_provide_session: + + def session_decorator(func): + def wrapper(*args, **kwargs): + return func(session=mock_session) + + return wrapper + + mock_provide_session.side_effect = session_decorator + + operator._override_rtif(context) + + mock_session.query.assert_called_once_with(RenderedTaskInstanceFields) + mock_query.filter.assert_called_once() + filter_call_arg_map_index = str(mock_query.filter.call_args.args[-1]) + assert filter_call_arg_map_index == "rendered_task_instance_fields.map_index = :map_index_1" + + def test_dbt_cmd_flags_templating(): """Test that dbt_cmd_flags supports Jinja templating.""" from datetime import datetime