Skip to content
50 changes: 37 additions & 13 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ def _add_watcher_producer_task(
task_id=PRODUCER_WATCHER_DONE_TASK_ID,
)
producer_airflow_task >> producer_done_task
tasks_map[PRODUCER_WATCHER_DONE_TASK_ID] = producer_done_task

return producer_airflow_task

Expand All @@ -766,7 +767,15 @@ def _add_watcher_dependencies(
Iterate through the watcher consumer tasks and:
- set the producer task ID in all of them
- make the producer task the parent of the root dbt nodes, without blocking them from sensing XCom
- if the producer task has depends_on_past=True, set wait_for_downstream=True in the producer and all its
watcher tasks, and put them upstream of the producer_watcher_done task so the entire task group
behaves as a single unit that needs to succeed for the next run to start.
"""
producer_watcher_done_task = tasks_map.get(PRODUCER_WATCHER_DONE_TASK_ID)
needs_wait_for_downstream = producer_watcher_done_task is not None and producer_airflow_task.depends_on_past
if needs_wait_for_downstream:
producer_airflow_task.wait_for_downstream = True
Comment thread
johnhoran marked this conversation as resolved.

for node_id, task_or_taskgroup in tasks_map.items():
# We do not want to set a dependency between the producer task (or its gate) and itself
if node_id in (PRODUCER_WATCHER_TASK_ID, PRODUCER_WATCHER_DONE_TASK_ID):
Expand All @@ -778,26 +787,40 @@ def _add_watcher_dependencies(
else [task_or_taskgroup]
)
for task in node_tasks:
task.producer_task_id = producer_airflow_task.task_id # type: ignore[union-attr]
if hasattr(task, "producer_task_id"):
task.producer_task_id = producer_airflow_task.task_id # type: ignore[union-attr]
if needs_wait_for_downstream:
task.wait_for_downstream = True # type: ignore[union-attr]
Comment thread
johnhoran marked this conversation as resolved.

if needs_wait_for_downstream and not task_or_taskgroup.downstream_task_ids:
task_or_taskgroup >> producer_watcher_done_task

# Make the producer task the parent of the root dbt nodes, without blocking them from sensing XCom
# We only managed to do this in the case of DbtDag.
# The way it is implemented is by setting the trigger_rule to "always" for the consumer tasks, and by having the producer task with a high priority_weight.
if "DbtDag" in dag.__class__.__name__:

if (
"DbtDag" in dag.__class__.__name__
and nodes
and node_id in nodes
and not set(nodes[node_id].depends_on).intersection(nodes)
):
# Is this dbt node a root of the (subset of the) dbt project?
# Note: this happens in one scenario:
# - the dbt node not having any `depends_on` within the user-selected `nodes`
if nodes and node_id in nodes and not set(nodes[node_id].depends_on).intersection(nodes):
producer_airflow_task >> task_or_taskgroup
if isinstance(task_or_taskgroup, TaskGroup):
taskgroup = task_or_taskgroup
always_run_tasks = [
task for task in node_tasks if not set(task.upstream_task_ids).intersection(taskgroup.children)
]
else:
always_run_tasks = [task_or_taskgroup]
for task in always_run_tasks:
task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[union-attr]
producer_airflow_task >> task_or_taskgroup
always_run_tasks = (
[
task
for task in node_tasks
if not set(task.upstream_task_ids).intersection(task_or_taskgroup.children)
]
if isinstance(task_or_taskgroup, TaskGroup)
else [task_or_taskgroup]
)

for task in always_run_tasks:
task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[union-attr]


def should_create_detached_nodes(render_config: RenderConfig) -> bool:
Expand Down Expand Up @@ -1029,6 +1052,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task
tasks_map[f"{dbt_project_name}_test"] = test_task
elif render_config.test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH):
# Handle detached test nodes
for node_id, node in detached_nodes.items():
Expand Down
75 changes: 75 additions & 0 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,3 +2228,78 @@ def test_add_watcher_producer_task_passes_freshness_callback_via_setup_operator_

task_metadata = mock_create_task.call_args[0][0]
assert task_metadata.arguments["freshness_callback"] is my_callback


@pytest.mark.parametrize("depends_on_past", [False, True])
@pytest.mark.parametrize("test_behavior", [TestBehavior.NONE, TestBehavior.AFTER_EACH, TestBehavior.AFTER_ALL])
def test_watcher_dependency_wiring(test_behavior, depends_on_past):
    """Check how the producer-watcher gate task is wired in WATCHER execution mode.

    The graph is built inside a task group named ``tg`` for every combination of
    ``test_behavior`` and the DAG-level ``depends_on_past`` default argument.

    - Without ``depends_on_past`` the gate task only has the producer as upstream and
      no task has ``wait_for_downstream`` set.
    - With ``depends_on_past`` every task except the gate has ``wait_for_downstream``
      set, the gate itself does not, and the gate's upstream tasks are the producer
      plus a set of task IDs that depends on ``test_behavior``.
    """
    with DAG("test-id", start_date=datetime(2022, 1, 1), default_args={"depends_on_past": depends_on_past}) as dag:
        task_args = {
            "project_dir": SAMPLE_PROJ_PATH,
            "conn_id": "fake_conn",
            "profile_config": ProfileConfig(
                profile_name="default",
                target_name="default",
                profile_mapping=PostgresUserPasswordProfileMapping(
                    conn_id="fake_conn",
                    profile_args={"schema": "public"},
                ),
            ),
        }

        # Extra model that carries its own (non-detached) test, plus that test node.
        child_2b = DbtNode(
            unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.child2.v2_b",
            resource_type=DbtResourceType.MODEL,
            depends_on=[parent_node.unique_id],
            path_base=SAMPLE_PROJ_PATH,
            original_file_path=Path("gen3/models/child2_v2.sql"),
            tags=["nightly"],
            config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}},
            has_test=True,
            has_non_detached_test=True,
        )
        child_2b_test = DbtNode(
            unique_id=f"{DbtResourceType.TEST.value}.{SAMPLE_PROJ_PATH.stem}.child2.test_v2_b",
            resource_type=DbtResourceType.TEST,
            depends_on=[child_2b.unique_id],
            path_base=Path("."),
            original_file_path=Path("."),
        )

        build_airflow_graph(
            nodes={child_2b.unique_id: child_2b, child_2b_test.unique_id: child_2b_test, **sample_nodes},
            dag=dag,
            execution_mode=ExecutionMode.WATCHER,
            test_indirect_selection=TestIndirectSelection.EAGER,
            task_args=task_args,
            render_config=RenderConfig(
                test_behavior=test_behavior,
            ),
            dbt_project_name="astro_shop",
            task_group=TaskGroup("tg", dag=dag),
        )

    done_task = dag.task_dict["tg.dbt_producer_watcher_done"]

    if not depends_on_past:
        assert done_task.upstream_task_ids == {"tg.dbt_producer_watcher"}
        assert all(task.wait_for_downstream is False for task in dag.tasks)
        return

    assert all(task.wait_for_downstream is True for task in dag.tasks if task.task_id != "tg.dbt_producer_watcher_done")
    # Make the intended gate behaviour explicit: the gate task itself is left untouched.
    assert done_task.wait_for_downstream is False

    expected_upstream_by_behavior = {
        TestBehavior.NONE: {
            "tg.child_run",
            "tg.dbt_producer_watcher",
            "tg.child2_v2_run",
            "tg.child2_v2_b_run",
        },
        TestBehavior.AFTER_EACH: {
            "tg.child_run",
            "tg.dbt_producer_watcher",
            "tg.child2_v2_run",
            "tg.child2_v2_b.test",
        },
        TestBehavior.AFTER_ALL: {
            "tg.dbt_producer_watcher",
            "tg.astro_shop_test",
        },
    }
    assert done_task.upstream_task_ids == expected_upstream_by_behavior[test_behavior]
Comment on lines +2250 to +2305
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new test covers the dependency topology well. A small extra assertion that tg.dbt_producer_watcher_done.wait_for_downstream is False would make the intended gate behaviour explicit.

2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def test_converter_creates_dag_with_test_with_multiple_parents_test_afterall():
)
tasks = converter.tasks_map

assert len(converter.tasks_map) == 3
assert len(converter.tasks_map) == 4

assert tasks["model.my_dbt_project.combined_model"].task_id == "combined_model_run"
assert tasks["model.my_dbt_project.model_a"].task_id == "model_a_run"
Expand Down
Loading