diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 95c9d242e4..2cbba41141 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -641,14 +641,17 @@ def _add_dbt_setup_async_task( tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task -def _add_producer_watcher_and_dependencies( +def _add_watcher_producer_task( dag: DAG, task_args: dict[str, Any], tasks_map: dict[str, Any], task_group: TaskGroup | None, render_config: RenderConfig | None = None, - nodes: dict[str, DbtNode] | None = None, -) -> str: +) -> BaseOperator: + """ + Create the producer task for the watcher execution mode and add it to the tasks_map. + The producer task is the task that will be used to produce the events for the watcher execution mode. + """ producer_task_args = task_args.copy() if render_config is not None: @@ -669,10 +672,23 @@ def _add_producer_watcher_and_dependencies( arguments=producer_task_args, ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) + tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task + return producer_airflow_task - # Second, we need to set the producer task ID in all consumer tasks (and their children tasks) - for node_id, task_or_taskgroup in tasks_map.items(): +def _add_watcher_dependencies( + dag: DAG, + producer_airflow_task: BaseOperator, + task_args: dict[str, Any], + tasks_map: dict[str, Any], + nodes: dict[str, DbtNode] | None = None, +) -> str: + """ + Iterate through the watcher consumer tasks and: + - set the producer task ID in all of them + - make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom + """ + for node_id, task_or_taskgroup in tasks_map.items(): node_tasks = ( list(task_or_taskgroup.children.values()) if isinstance(task_or_taskgroup, TaskGroup) @@ -681,7 +697,7 @@ def _add_producer_watcher_and_dependencies( for task in node_tasks: task.producer_task_id = producer_airflow_task.task_id # type: ignore[attr-defined] - # Third, we want to make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom + # Make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom # We only managed to do this in the case of DbtDag. # The way it is implemented is by setting the trigger_rule to "always" for the consumer tasks, and by having the producer task with a high priority_weight. if "DbtDag" in dag.__class__.__name__: @@ -702,7 +718,6 @@ def _add_producer_watcher_and_dependencies( for task in always_run_tasks: task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined] - tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task return producer_airflow_task.task_id @@ -844,6 +859,18 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro if execution_mode == ExecutionMode.AIRFLOW_ASYNC: # This property is only relevant for the setup task, not the other tasks: virtualenv_dir = task_args.pop("virtualenv_dir", None) + elif execution_mode == ExecutionMode.WATCHER: + setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} + # We are intentionally creating the producer task ahead of the consumer tasks + # Airflow priority weight is not being respected in multiple versions of the library, including 3.1 + # To instantiate the producer before helps having it before on the DAG topological order and scheduling this task before the consumer tasks + producer_task = _add_watcher_producer_task( + dag=dag, + task_args={**task_args, **setup_operator_args}, + tasks_map=tasks_map, + task_group=task_group, + render_config=render_config, + ) for node_id, node in nodes.items(): task_or_group_args = { @@ -915,12 +942,11 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro if execution_mode == ExecutionMode.WATCHER: setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} - _add_producer_watcher_and_dependencies( + _add_watcher_dependencies( dag=dag, - task_args={**task_args, **setup_operator_args}, + producer_airflow_task=producer_task, + task_args=task_args, tasks_map=tasks_map, - task_group=task_group, - render_config=render_config, nodes=nodes, ) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 04bdcf0518..c56d5e87c2 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -1102,10 +1102,10 @@ def test_dbt_task_group_with_watcher(): expected_task_names = [ "pre_dbt", + "dbt_task_group.dbt_producer_watcher", "dbt_task_group.raw_customers_seed", "dbt_task_group.raw_orders_seed", "dbt_task_group.raw_payments_seed", - "dbt_task_group.dbt_producer_watcher", "dbt_task_group.stg_customers_run", "dbt_task_group.stg_orders_run", "dbt_task_group.stg_payments_run",