From c8b4a5eff1318608788f1673c8386d47cec5bbf0 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 26 Sep 2023 23:45:46 -0700 Subject: [PATCH 1/3] Run triggers inline with dag test No need to have trigger running -- will just run them async. --- airflow/models/dag.py | 67 ++++----- airflow/models/taskinstance.py | 3 + tests/cli/commands/test_dag_command.py | 81 +++++----- tests/models/test_mappedoperator.py | 2 +- .../aws/sensors/test_s3_keys_unchanged.py | 140 ++++++++++++++++++ 5 files changed, 221 insertions(+), 72 deletions(-) create mode 100644 tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 428f79e14e67a..9abf9b0948edd 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import asyncio +import collections import collections.abc import copy import functools @@ -82,11 +84,11 @@ from airflow.exceptions import ( AirflowDagInconsistent, AirflowException, - AirflowSkipException, DuplicateTaskIdFound, FailStopDagInvalidTriggerRule, ParamValidationError, RemovedInAirflow3Warning, + TaskDeferred, TaskNotFound, ) from airflow.jobs.job import run_job @@ -101,7 +103,6 @@ Context, TaskInstance, TaskInstanceKey, - TaskReturnCode, clear_task_instances, ) from airflow.secrets.local_filesystem import LocalFilesystemBackend @@ -285,12 +286,11 @@ def get_dataset_triggered_next_run_info( } -class _StopDagTest(Exception): - """ - Raise when DAG.test should stop immediately. +def _triggerer_is_healthy(): + from airflow.jobs.triggerer_job_runner import TriggererJobRunner - :meta private: - """ + job = TriggererJobRunner.most_recent_job() + return job and job.is_alive() @functools.total_ordering @@ -2844,21 +2844,12 @@ def add_logger_if_needed(ti: TaskInstance): if not scheduled_tis and ids_unrunnable: self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) time.sleep(1) + triggerer_running = _triggerer_is_healthy() for ti in scheduled_tis: try: add_logger_if_needed(ti) ti.task = tasks[ti.task_id] - ret = _run_task(ti, session=session) - if ret is TaskReturnCode.DEFERRED: - if not _triggerer_is_healthy(): - raise _StopDagTest( - "Task has deferred but triggerer component is not running. " - "You can start the triggerer by running `airflow triggerer` in a terminal." - ) - except _StopDagTest: - # Let this exception bubble out and not be swallowed by the - # except block below. - raise + _run_task(ti=ti, inline_trigger=not triggerer_running, session=session) except Exception: self.log.exception("Task failed; ti=%s", ti) if conn_file_path or variable_file_path: @@ -3988,14 +3979,15 @@ def get_current_dag(cls) -> DAG | None: return None -def _triggerer_is_healthy(): - from airflow.jobs.triggerer_job_runner import TriggererJobRunner +def _run_trigger(trigger): + async def _run_trigger_main(): + async for event in trigger.run(): + return event - job = TriggererJobRunner.most_recent_job() - return job and job.is_alive() + return asyncio.run(_run_trigger_main()) -def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: +def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session): """ Run a single task instance, and push result to Xcom for downstream tasks. @@ -4005,20 +3997,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: Args: ti: TaskInstance to run """ - ret = None - log.info("*****************************************************") - if ti.map_index > 0: - log.info("Running task %s index %d", ti.task_id, ti.map_index) - else: - log.info("Running task %s", ti.task_id) - try: - ret = ti._run_raw_task(session=session) - session.flush() - log.info("%s ran successfully!", ti.task_id) - except AirflowSkipException: - log.info("Task Skipped, continuing") - log.info("*****************************************************") - return ret + log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) + while True: + try: + log.info("[DAG TEST] running task %s", ti) + ti._run_raw_task(session=session, raise_on_defer=inline_trigger) + break + except TaskDeferred as e: + log.info("[DAG TEST] running trigger in line") + event = _run_trigger(e.trigger) + ti.next_method = e.method_name + ti.next_kwargs = {"event": event.payload} if event else e.kwargs + log.info("[DAG TEST] Trigger completed") + session.merge(ti) + session.commit() + log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) def _get_or_create_dagrun( diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index fe09bba732c4f..48146724b55d1 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2207,6 +2207,7 @@ def _run_raw_task( test_mode: bool = False, job_id: str | None = None, pool: str | None = None, + raise_on_defer: bool = False, session: Session = NEW_SESSION, ) -> TaskReturnCode | None: """ @@ -2261,6 +2262,8 @@ def _run_raw_task( except TaskDeferred as defer: # The task has signalled it wants to defer execution based on # a trigger. + if raise_on_defer: + raise self._defer_task(defer=defer, session=session) self.log.info( "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 78b7fd4525f2d..30b5c475ea4a9 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -37,9 +37,10 @@ from airflow.exceptions import AirflowException from airflow.models import DagBag, DagModel, DagRun from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import _StopDagTest +from airflow.models.dag import _run_trigger from airflow.models.serialized_dag import SerializedDagModel -from airflow.triggers.temporal import TimeDeltaTrigger +from airflow.triggers.base import TriggerEvent +from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.types import DagRunType @@ -824,35 +825,47 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _): dag_command.dag_test(cli_args) assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs - def test_dag_test_no_triggerer(self, dag_maker): - with dag_maker() as dag: - - @task - def one(): - return 1 - - @task - def two(val): - return val + 1 - - class MyOp(BaseOperator): - template_fields = ("tfield",) - - def __init__(self, tfield, **kwargs): - self.tfield = tfield - super().__init__(**kwargs) - - def execute(self, context, event=None): - if event is None: - print("I AM DEFERRING") - self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute") - return - print("RESUMING") - return self.tfield + 1 - - task_one = one() - task_two = two(task_one) - op = MyOp(task_id="abc", tfield=str(task_two)) - task_two >> op - with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"): - dag.test() + def test_dag_test_run_trigger(self, dag_maker): + now = timezone.utcnow() + trigger = DateTimeTrigger(moment=now) + e = _run_trigger(trigger) + assert isinstance(e, TriggerEvent) + assert e.payload == now + + def test_dag_test_no_triggerer_running(self, dag_maker): + with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run: + with dag_maker() as dag: + + @task + def one(): + return 1 + + @task + def two(val): + return val + 1 + + trigger = TimeDeltaTrigger(timedelta(seconds=0)) + + class MyOp(BaseOperator): + template_fields = ("tfield",) + + def __init__(self, tfield, **kwargs): + self.tfield = tfield + super().__init__(**kwargs) + + def execute(self, context, event=None): + if event is None: + print("I AM DEFERRING") + self.defer(trigger=trigger, method_name="execute") + return + print("RESUMING") + return self.tfield + 1 + + task_one = one() + task_two = two(task_one) + op = MyOp(task_id="abc", tfield=task_two) + task_two >> op + dr = dag.test() + assert mock_run.call_args_list[0] == ((trigger,), {}) + tis = dr.get_task_instances() + assert [x for x in tis if x.task_id == "abc"][0].state == "success" diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 5c2e23c1f9e30..e016125ea8ba0 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -95,7 +95,7 @@ def execute(self, context: Context): mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values) task1 >> mapped dag.test() - assert caplog.text.count("task_2 ran successfully") == 2 + assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2 assert ( "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" in caplog.text diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py new file mode 100644 index 0000000000000..a90b2234cbe51 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime +from unittest import mock + +import pytest +import time_machine + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor + +TEST_DAG_ID = "unit_tests_aws_sensor" +DEFAULT_DATE = datetime(2015, 1, 1) + + +class TestS3KeysUnchangedSensor: + def setup_method(self): + self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", start_date=DEFAULT_DATE, schedule="@once") + + self.sensor = S3KeysUnchangedSensor( + task_id="sensor_1", + bucket_name="test-bucket", + prefix="test-prefix/path", + inactivity_period=12, + poke_interval=0.1, + min_objects=1, + allow_delete=True, + dag=self.dag, + ) + + def test_reschedule_mode_not_allowed(self): + with pytest.raises(ValueError): + S3KeysUnchangedSensor( + task_id="sensor_2", + bucket_name="test-bucket", + prefix="test-prefix/path", + poke_interval=0.1, + mode="reschedule", + dag=self.dag, + ) + + def test_render_template_fields(self): + S3KeysUnchangedSensor( + task_id="sensor_3", + bucket_name="test-bucket", + prefix="test-prefix/path", + inactivity_period=12, + poke_interval=0.1, + min_objects=1, + allow_delete=True, + dag=self.dag, + ).render_template_fields({}) + + @time_machine.travel(DEFAULT_DATE) + def test_files_deleted_between_pokes_throw_error(self): + self.sensor.allow_delete = False + self.sensor.is_keys_unchanged({"a", "b"}) + with pytest.raises(AirflowException): + self.sensor.is_keys_unchanged({"a"}) + + @pytest.mark.parametrize( + "current_objects, expected_returns, inactivity_periods", + [ + pytest.param( + ({"a"}, {"a", "b"}, {"a", "b", "c"}), + (False, False, False), + (0, 0, 0), + id="resetting inactivity period after key change", + ), + pytest.param( + ({"a", "b"}, {"a"}, {"a", "c"}), + (False, False, False), + (0, 0, 0), + id="item was deleted with option `allow_delete=True`", + ), + pytest.param( + ({"a"}, {"a"}, {"a"}), (False, False, True), (0, 10, 20), id="inactivity period was exceeded" + ), + pytest.param( + (set(), set(), set()), (False, False, False), (0, 10, 20), id="not pass if empty key is given" + ), + ], + ) + def test_key_changes(self, current_objects, expected_returns, inactivity_periods, time_machine): + time_machine.move_to(DEFAULT_DATE) + for current, expected, period in zip(current_objects, expected_returns, inactivity_periods): + assert self.sensor.is_keys_unchanged(current) == expected + assert self.sensor.inactivity_seconds == period + time_machine.coordinates.shift(10) + + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") + def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine): + time_machine.move_to(DEFAULT_DATE) + mock_hook.return_value.list_keys.return_value = {"a"} + assert not self.sensor.poke(dict()) + time_machine.coordinates.shift(10) + assert not self.sensor.poke(dict()) + time_machine.coordinates.shift(10) + assert self.sensor.poke(dict()) + + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_fail_is_keys_unchanged(self, soft_fail, expected_exception): + op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path") + op.soft_fail = soft_fail + op.previous_objects = {"1", "2", "3"} + current_objects = {"1", "2"} + op.allow_delete = False + message = "Illegal behavior: objects were deleted in" + with pytest.raises(expected_exception, match=message): + op.is_keys_unchanged(current_objects=current_objects) + + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + ) + def test_fail_execute_complete(self, soft_fail, expected_exception): + op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path") + op.soft_fail = soft_fail + message = "test message" + with pytest.raises(expected_exception, match=message): + op.execute_complete(context={}, event={"status": "error", "message": message}) From b54ec5a8bcbc01e65ac0c191de3d9da7d32fd250 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 13 Nov 2023 09:44:25 -0800 Subject: [PATCH 2/3] remove added file --- .../aws/sensors/test_s3_keys_unchanged.py | 140 ------------------ 1 file changed, 140 deletions(-) delete mode 100644 tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py diff --git a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py b/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py deleted file mode 100644 index a90b2234cbe51..0000000000000 --- a/tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py +++ /dev/null @@ -1,140 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import datetime -from unittest import mock - -import pytest -import time_machine - -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.models.dag import DAG -from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor - -TEST_DAG_ID = "unit_tests_aws_sensor" -DEFAULT_DATE = datetime(2015, 1, 1) - - -class TestS3KeysUnchangedSensor: - def setup_method(self): - self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", start_date=DEFAULT_DATE, schedule="@once") - - self.sensor = S3KeysUnchangedSensor( - task_id="sensor_1", - bucket_name="test-bucket", - prefix="test-prefix/path", - inactivity_period=12, - poke_interval=0.1, - min_objects=1, - allow_delete=True, - dag=self.dag, - ) - - def test_reschedule_mode_not_allowed(self): - with pytest.raises(ValueError): - S3KeysUnchangedSensor( - task_id="sensor_2", - bucket_name="test-bucket", - prefix="test-prefix/path", - poke_interval=0.1, - mode="reschedule", - dag=self.dag, - ) - - def test_render_template_fields(self): - S3KeysUnchangedSensor( - task_id="sensor_3", - bucket_name="test-bucket", - prefix="test-prefix/path", - inactivity_period=12, - poke_interval=0.1, - min_objects=1, - allow_delete=True, - dag=self.dag, - ).render_template_fields({}) - - @time_machine.travel(DEFAULT_DATE) - def test_files_deleted_between_pokes_throw_error(self): - self.sensor.allow_delete = False - self.sensor.is_keys_unchanged({"a", "b"}) - with pytest.raises(AirflowException): - self.sensor.is_keys_unchanged({"a"}) - - @pytest.mark.parametrize( - "current_objects, expected_returns, inactivity_periods", - [ - pytest.param( - ({"a"}, {"a", "b"}, {"a", "b", "c"}), - (False, False, False), - (0, 0, 0), - id="resetting inactivity period after key change", - ), - pytest.param( - ({"a", "b"}, {"a"}, {"a", "c"}), - (False, False, False), - (0, 0, 0), - id="item was deleted with option `allow_delete=True`", - ), - pytest.param( - ({"a"}, {"a"}, {"a"}), (False, False, True), (0, 10, 20), id="inactivity period was exceeded" - ), - pytest.param( - (set(), set(), set()), (False, False, False), (0, 10, 20), id="not pass if empty key is given" - ), - ], - ) - def test_key_changes(self, current_objects, expected_returns, inactivity_periods, time_machine): - time_machine.move_to(DEFAULT_DATE) - for current, expected, period in zip(current_objects, expected_returns, inactivity_periods): - assert self.sensor.is_keys_unchanged(current) == expected - assert self.sensor.inactivity_seconds == period - time_machine.coordinates.shift(10) - - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine): - time_machine.move_to(DEFAULT_DATE) - mock_hook.return_value.list_keys.return_value = {"a"} - assert not self.sensor.poke(dict()) - time_machine.coordinates.shift(10) - assert not self.sensor.poke(dict()) - time_machine.coordinates.shift(10) - assert self.sensor.poke(dict()) - - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail_is_keys_unchanged(self, soft_fail, expected_exception): - op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path") - op.soft_fail = soft_fail - op.previous_objects = {"1", "2", "3"} - current_objects = {"1", "2"} - op.allow_delete = False - message = "Illegal behavior: objects were deleted in" - with pytest.raises(expected_exception, match=message): - op.is_keys_unchanged(current_objects=current_objects) - - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail_execute_complete(self, soft_fail, expected_exception): - op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path") - op.soft_fail = soft_fail - message = "test message" - with pytest.raises(expected_exception, match=message): - op.execute_complete(context={}, event={"status": "error", "message": message}) From ada6e1b5074b33059ca86b727400cbd143792599 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 13 Nov 2023 14:28:12 -0800 Subject: [PATCH 3/3] simplify imports --- airflow/models/dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9abf9b0948edd..d693f41931797 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -19,7 +19,6 @@ import asyncio import collections -import collections.abc import copy import functools import itertools