Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 79 additions & 46 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,71 @@ def _add_teardown_task(
tasks_map[DBT_TEARDOWN_ASYNC_TASK_ID] = teardown_airflow_task


def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astronomer-cosmos/issues/1943
def _add_test_tasks(
render_config: RenderConfig,
dbt_project_name: str,
execution_mode: ExecutionMode,
test_indirect_selection: TestIndirectSelection,
task_args: dict[str, Any],
on_warning_callback: Callable[..., Any] | None,
task_or_group_args: dict[str, Any],
parent_task_group: TaskGroup | None,
tasks_map: dict[str, TaskGroup | BaseOperator],
nodes: dict[str, DbtNode],
detached_nodes: dict[str, DbtNode],
) -> None:
"""
Add test tasks to tasks_map based on the render_config test_behavior setting.

Handles AFTER_ALL (single test task after DAG leaves) and BUILD/AFTER_EACH
(individual detached test tasks per node).
"""
if render_config.test_behavior == TestBehavior.AFTER_ALL:
test_meta = create_test_task_metadata(
f"{dbt_project_name}_test",
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
enable_owner_inheritance=render_config.enable_owner_inheritance,
)
test_task_args = {
**task_or_group_args,
"task_meta": test_meta,
"resource_type": DbtResourceType.TEST,
}
# AFTER_ALL test is a single DAG-level task: place it at root so task_id stays e.g. "astro_shop_test"
test_task_args["task_group"] = parent_task_group
test_task = generate_or_convert_task(**test_task_args) # type: ignore[arg-type]
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task
tasks_map[f"{dbt_project_name}_test"] = test_task
elif render_config.test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH):
# Handle detached test nodes
for node_id, node in detached_nodes.items():
datached_node_name = calculate_detached_node_name(node)
test_meta = create_test_task_metadata(
datached_node_name,
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
node=node,
enable_owner_inheritance=render_config.enable_owner_inheritance,
)
test_task_args = {
**task_or_group_args,
"task_meta": test_meta,
"resource_type": node.resource_type,
}
test_task = generate_or_convert_task(**test_task_args) # type: ignore[arg-type]
tasks_map[node_id] = test_task


def build_airflow_graph(
nodes: dict[str, DbtNode],
dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
execution_mode: ExecutionMode, # Cosmos-specific - decide what which class to use
Expand Down Expand Up @@ -1000,6 +1064,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
tests_per_model=tests_per_model,
)

task_or_group_args: dict[str, Any] = {}
for node_id, node in nodes.items():
task_group = (
create_task_groups_based_on_folder(dag, node, parent_task_group, task_groups)
Expand Down Expand Up @@ -1029,51 +1094,19 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
tasks_map[node_id] = task_or_group
task_group = parent_task_group

# If test_behaviour=="after_all", there will be one test task, run by the end of the DAG
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if render_config.test_behavior == TestBehavior.AFTER_ALL:
test_meta = create_test_task_metadata(
f"{dbt_project_name}_test",
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
enable_owner_inheritance=render_config.enable_owner_inheritance,
)
test_task_args = {
**task_or_group_args,
"task_meta": test_meta,
"resource_type": DbtResourceType.TEST,
}
# AFTER_ALL test is a single DAG-level task: place it at root so task_id stays e.g. "astro_shop_test"
test_task_args["task_group"] = parent_task_group
test_task = generate_or_convert_task(**test_task_args) # type: ignore[arg-type]
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task
tasks_map[f"{dbt_project_name}_test"] = test_task
elif render_config.test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH):
# Handle detached test nodes
for node_id, node in detached_nodes.items():
datached_node_name = calculate_detached_node_name(node)
test_meta = create_test_task_metadata(
datached_node_name,
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
node=node,
enable_owner_inheritance=render_config.enable_owner_inheritance,
)
test_task_args = {
**task_or_group_args,
"task_meta": test_meta,
"resource_type": node.resource_type,
}
test_task = generate_or_convert_task(**test_task_args) # type: ignore[arg-type]
tasks_map[node_id] = test_task
_add_test_tasks(
render_config=render_config,
dbt_project_name=dbt_project_name,
execution_mode=execution_mode,
test_indirect_selection=test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
task_or_group_args=task_or_group_args,
parent_task_group=parent_task_group,
tasks_map=tasks_map,
nodes=nodes,
detached_nodes=detached_nodes,
)

create_airflow_task_dependencies(nodes, tasks_map)

Expand Down