Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,14 +641,17 @@ def _add_dbt_setup_async_task(
tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task


def _add_producer_watcher_and_dependencies(
def _add_watcher_producer_task(
dag: DAG,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
nodes: dict[str, DbtNode] | None = None,
) -> str:
) -> BaseOperator:
"""
Create the producer task for the watcher execution mode and add it to the tasks_map.
The producer task is the task that will be used to produce the events for the watcher execution mode.
"""
producer_task_args = task_args.copy()

if render_config is not None:
Expand All @@ -669,10 +672,23 @@ def _add_producer_watcher_and_dependencies(
arguments=producer_task_args,
)
producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group)
tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task
return producer_airflow_task

# Second, we need to set the producer task ID in all consumer tasks (and their children tasks)
for node_id, task_or_taskgroup in tasks_map.items():

def _add_watcher_dependencies(
dag: DAG,
producer_airflow_task: BaseOperator,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
nodes: dict[str, DbtNode] | None = None,
) -> str:
"""
Iterate through the watcher consumer tasks and:
- set the producer task ID in all of them
- make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom
"""
for node_id, task_or_taskgroup in tasks_map.items():
node_tasks = (
list(task_or_taskgroup.children.values())
if isinstance(task_or_taskgroup, TaskGroup)
Expand All @@ -681,7 +697,7 @@ def _add_producer_watcher_and_dependencies(
for task in node_tasks:
task.producer_task_id = producer_airflow_task.task_id # type: ignore[attr-defined]

# Third, we want to make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom
# Make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom
# We only managed to do this in the case of DbtDag.
# The way it is implemented is by setting the trigger_rule to "always" for the consumer tasks, and by having the producer task with a high priority_weight.
if "DbtDag" in dag.__class__.__name__:
Expand All @@ -702,7 +718,6 @@ def _add_producer_watcher_and_dependencies(
for task in always_run_tasks:
task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined]

tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task
return producer_airflow_task.task_id


Expand Down Expand Up @@ -844,6 +859,18 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
if execution_mode == ExecutionMode.AIRFLOW_ASYNC:
# This property is only relevant for the setup task, not the other tasks:
virtualenv_dir = task_args.pop("virtualenv_dir", None)
elif execution_mode == ExecutionMode.WATCHER:
setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {}
# We are intentionally creating the producer task ahead of the consumer tasks
# Airflow priority weight is not being respected in multiple versions of the library, including 3.1
# To instantiate the producer before helps having it before on the DAG topological order and scheduling this task before the consumer tasks
producer_task = _add_watcher_producer_task(
Comment thread
tatiana marked this conversation as resolved.
dag=dag,
task_args={**task_args, **setup_operator_args},
tasks_map=tasks_map,
task_group=task_group,
render_config=render_config,
)

for node_id, node in nodes.items():
task_or_group_args = {
Expand Down Expand Up @@ -915,12 +942,11 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro

if execution_mode == ExecutionMode.WATCHER:
setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {}
_add_producer_watcher_and_dependencies(
_add_watcher_dependencies(
dag=dag,
task_args={**task_args, **setup_operator_args},
producer_airflow_task=producer_task,
task_args=task_args,
tasks_map=tasks_map,
task_group=task_group,
render_config=render_config,
nodes=nodes,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,10 +1102,10 @@ def test_dbt_task_group_with_watcher():

expected_task_names = [
"pre_dbt",
"dbt_task_group.dbt_producer_watcher",
"dbt_task_group.raw_customers_seed",
"dbt_task_group.raw_orders_seed",
"dbt_task_group.raw_payments_seed",
"dbt_task_group.dbt_producer_watcher",
"dbt_task_group.stg_customers_run",
"dbt_task_group.stg_orders_run",
"dbt_task_group.stg_payments_run",
Expand Down