diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 428f79e14e67a..d693f41931797 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -17,7 +17,8 @@ # under the License. from __future__ import annotations -import collections.abc +import asyncio +import collections import copy import functools import itertools @@ -82,11 +83,11 @@ from airflow.exceptions import ( AirflowDagInconsistent, AirflowException, - AirflowSkipException, DuplicateTaskIdFound, FailStopDagInvalidTriggerRule, ParamValidationError, RemovedInAirflow3Warning, + TaskDeferred, TaskNotFound, ) from airflow.jobs.job import run_job @@ -101,7 +102,6 @@ Context, TaskInstance, TaskInstanceKey, - TaskReturnCode, clear_task_instances, ) from airflow.secrets.local_filesystem import LocalFilesystemBackend @@ -285,12 +285,11 @@ def get_dataset_triggered_next_run_info( } -class _StopDagTest(Exception): - """ - Raise when DAG.test should stop immediately. +def _triggerer_is_healthy(): + from airflow.jobs.triggerer_job_runner import TriggererJobRunner - :meta private: - """ + job = TriggererJobRunner.most_recent_job() + return job and job.is_alive() @functools.total_ordering @@ -2844,21 +2843,12 @@ def add_logger_if_needed(ti: TaskInstance): if not scheduled_tis and ids_unrunnable: self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) time.sleep(1) + triggerer_running = _triggerer_is_healthy() for ti in scheduled_tis: try: add_logger_if_needed(ti) ti.task = tasks[ti.task_id] - ret = _run_task(ti, session=session) - if ret is TaskReturnCode.DEFERRED: - if not _triggerer_is_healthy(): - raise _StopDagTest( - "Task has deferred but triggerer component is not running. " - "You can start the triggerer by running `airflow triggerer` in a terminal." - ) - except _StopDagTest: - # Let this exception bubble out and not be swallowed by the - # except block below. - raise + _run_task(ti=ti, inline_trigger=not triggerer_running, session=session) except Exception: self.log.exception("Task failed; ti=%s", ti) if conn_file_path or variable_file_path: @@ -3988,14 +3978,15 @@ def get_current_dag(cls) -> DAG | None: return None -def _triggerer_is_healthy(): - from airflow.jobs.triggerer_job_runner import TriggererJobRunner +def _run_trigger(trigger): + async def _run_trigger_main(): + async for event in trigger.run(): + return event - job = TriggererJobRunner.most_recent_job() - return job and job.is_alive() + return asyncio.run(_run_trigger_main()) -def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: +def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session): """ Run a single task instance, and push result to Xcom for downstream tasks. @@ -4005,20 +3996,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: Args: ti: TaskInstance to run """ - ret = None - log.info("*****************************************************") - if ti.map_index > 0: - log.info("Running task %s index %d", ti.task_id, ti.map_index) - else: - log.info("Running task %s", ti.task_id) - try: - ret = ti._run_raw_task(session=session) - session.flush() - log.info("%s ran successfully!", ti.task_id) - except AirflowSkipException: - log.info("Task Skipped, continuing") - log.info("*****************************************************") - return ret + log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) + while True: + try: + log.info("[DAG TEST] running task %s", ti) + ti._run_raw_task(session=session, raise_on_defer=inline_trigger) + break + except TaskDeferred as e: + log.info("[DAG TEST] running trigger in line") + event = _run_trigger(e.trigger) + ti.next_method = e.method_name + ti.next_kwargs = {"event": event.payload} if event else e.kwargs + log.info("[DAG TEST] Trigger completed") + session.merge(ti) + session.commit() + log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) def _get_or_create_dagrun( diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index fe09bba732c4f..48146724b55d1 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2207,6 +2207,7 @@ def _run_raw_task( test_mode: bool = False, job_id: str | None = None, pool: str | None = None, + raise_on_defer: bool = False, session: Session = NEW_SESSION, ) -> TaskReturnCode | None: """ @@ -2261,6 +2262,8 @@ def _run_raw_task( except TaskDeferred as defer: # The task has signalled it wants to defer execution based on # a trigger. + if raise_on_defer: + raise self._defer_task(defer=defer, session=session) self.log.info( "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 78b7fd4525f2d..30b5c475ea4a9 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -37,9 +37,10 @@ from airflow.exceptions import AirflowException from airflow.models import DagBag, DagModel, DagRun from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import _StopDagTest +from airflow.models.dag import _run_trigger from airflow.models.serialized_dag import SerializedDagModel -from airflow.triggers.temporal import TimeDeltaTrigger +from airflow.triggers.base import TriggerEvent +from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.types import DagRunType @@ -824,35 +825,47 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _): dag_command.dag_test(cli_args) assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs - def test_dag_test_no_triggerer(self, dag_maker): - with dag_maker() as dag: - - @task - def one(): - return 1 - - @task - def two(val): - return val + 1 - - class MyOp(BaseOperator): - template_fields = ("tfield",) - - def __init__(self, tfield, **kwargs): - self.tfield = tfield - super().__init__(**kwargs) - - def execute(self, context, event=None): - if event is None: - print("I AM DEFERRING") - self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute") - return - print("RESUMING") - return self.tfield + 1 - - task_one = one() - task_two = two(task_one) - op = MyOp(task_id="abc", tfield=str(task_two)) - task_two >> op - with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"): - dag.test() + def test_dag_test_run_trigger(self, dag_maker): + now = timezone.utcnow() + trigger = DateTimeTrigger(moment=now) + e = _run_trigger(trigger) + assert isinstance(e, TriggerEvent) + assert e.payload == now + + def test_dag_test_no_triggerer_running(self, dag_maker): + with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run: + with dag_maker() as dag: + + @task + def one(): + return 1 + + @task + def two(val): + return val + 1 + + trigger = TimeDeltaTrigger(timedelta(seconds=0)) + + class MyOp(BaseOperator): + template_fields = ("tfield",) + + def __init__(self, tfield, **kwargs): + self.tfield = tfield + super().__init__(**kwargs) + + def execute(self, context, event=None): + if event is None: + print("I AM DEFERRING") + self.defer(trigger=trigger, method_name="execute") + return + print("RESUMING") + return self.tfield + 1 + + task_one = one() + task_two = two(task_one) + op = MyOp(task_id="abc", tfield=task_two) + task_two >> op + dr = dag.test() + assert mock_run.call_args_list[0] == ((trigger,), {}) + tis = dr.get_task_instances() + assert [x for x in tis if x.task_id == "abc"][0].state == "success" diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 5c2e23c1f9e30..e016125ea8ba0 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -95,7 +95,7 @@ def execute(self, context: Context): mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values) task1 >> mapped dag.test() - assert caplog.text.count("task_2 ran successfully") == 2 + assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2 assert ( "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" in caplog.text