Skip to content

Commit

Permalink
Ensure teardown tasks are executed when DAG run is set to failed
Browse files Browse the repository at this point in the history
  • Loading branch information
jscheffl committed Jan 9, 2025
1 parent a3c51cd commit b9d8425
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
28 changes: 16 additions & 12 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,6 @@ def set_dag_run_state_to_failed(
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")

# Mark the dag run to failed.
if commit:
_set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

running_states = (
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
Expand All @@ -292,7 +288,7 @@ def set_dag_run_state_to_failed(

# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = session.scalars(
running_tis: list[TaskInstance] = session.scalars(
select(TaskInstance).where(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
Expand All @@ -301,16 +297,17 @@ def set_dag_run_state_to_failed(
)
)

task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]
# Do not kill teardown tasks
task_ids_of_running_tis = [ti.task_id for ti in running_tis if not dag.task_dict[ti.task_id].is_teardown]

tasks = []
running_tasks = []
for task in dag.tasks:
if task.task_id in task_ids_of_running_tis:
task.dag = dag
tasks.append(task)
running_tasks.append(task)

# Mark non-finished tasks as SKIPPED.
tis = session.scalars(
pending_tis: list[TaskInstance] = session.scalars(
select(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
Expand All @@ -324,12 +321,19 @@ def set_dag_run_state_to_failed(
)
).all()

# Do not skip teardown tasks
pending_tis = [ti for ti in pending_tis if not dag.task_dict[ti.task_id].is_teardown]

if commit:
for ti in tis:
for ti in pending_tis:
ti.set_state(TaskInstanceState.SKIPPED)

return tis + set_state(
tasks=tasks,
# Mark the dag run as failed only if no teardown is still pending (otherwise the teardown would not be scheduled later).
if not any(dag.task_dict[ti.task_id].is_teardown for ti in (running_tis + pending_tis)):
_set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)

return pending_tis + set_state(
tasks=running_tasks,
run_id=run_id,
state=TaskInstanceState.FAILED,
commit=commit,
Expand Down
54 changes: 54 additions & 0 deletions tests/api/common/test_mark_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from airflow.api.common.mark_tasks import set_dag_run_state_to_failed
from airflow.operators.empty import EmptyOperator
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance

from tests_common.pytest_plugin import DagMaker

# Module-level pytest marker: applies the ``db_test`` mark to every test in this file.
pytestmark = pytest.mark.db_test


def test_set_dag_run_state_to_failed(dag_maker: DagMaker):
    """Check that failing a DAG run leaves the teardown task untouched.

    The DAG holds a teardown task plus two regular tasks, one of which is put
    into RUNNING before the call. ``set_dag_run_state_to_failed`` is expected
    to return exactly the two regular task instances: the running one as
    FAILED and the not-yet-started one as SKIPPED. The teardown task must not
    be among the returned task instances.
    """
    # Build the DAG: two regular tasks created inside the teardown's context.
    with dag_maker("TEST_DAG_1"):
        with EmptyOperator(task_id="teardown").as_teardown():
            EmptyOperator(task_id="running")
            EmptyOperator(task_id="pending")
    dag_run = dag_maker.create_dagrun()

    # Only the task named "running" is moved into the RUNNING state.
    for task_instance in dag_run.get_task_instances():
        if task_instance.task_id != "running":
            continue
        task_instance.set_state(TaskInstanceState.RUNNING)
    dag_maker.session.flush()
    assert dag_run.dag

    updated_tis: list[TaskInstance] = set_dag_run_state_to_failed(
        dag=dag_run.dag,
        run_id=dag_run.run_id,
        commit=True,
        session=dag_maker.session,
    )
    assert len(updated_tis) == 2

    state_by_task_id = {updated.task_id: updated.state for updated in updated_tis}
    assert state_by_task_id["running"] == TaskInstanceState.FAILED
    assert state_by_task_id["pending"] == TaskInstanceState.SKIPPED
    assert "teardown" not in state_by_task_id

0 comments on commit b9d8425

Please sign in to comment.