diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index f4b493757d3ac..364aeafb81e6f 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -23,6 +23,7 @@
 import gzip as gz
 import io
 import logging
+import os
 import re
 import shutil
 import warnings
@@ -632,6 +633,117 @@ def _is_in_period(input_date: datetime) -> bool:
 
         return [k["Key"] for k in keys if _is_in_period(k["LastModified"])]
 
+    async def is_keys_unchanged_async(
+        self,
+        client: AioBaseClient,
+        bucket_name: str,
+        prefix: str,
+        inactivity_period: float = 60 * 60,
+        min_objects: int = 1,
+        previous_objects: set[str] | None = None,
+        inactivity_seconds: int = 0,
+        allow_delete: bool = True,
+        last_activity_time: datetime | None = None,
+    ) -> dict[str, Any]:
+        """
+        Check whether new objects have been uploaded and the inactivity_period
+        has passed, and update the state of the sensor accordingly.
+
+        :param client: aiobotocore client
+        :param bucket_name: the name of the bucket
+        :param prefix: a key prefix
+        :param inactivity_period: the total seconds of inactivity to designate
+            keys unchanged. Note, this mechanism is not real time and
+            this operator may not return until a poke_interval after this period
+            has passed with no additional objects sensed.
+        :param min_objects: the minimum number of objects needed for keys unchanged
+            sensor to be considered valid.
+        :param previous_objects: the set of object ids found during the last poke.
+        :param inactivity_seconds: number of inactive seconds
+        :param allow_delete: Should this sensor consider objects being deleted
+            between pokes valid behavior. If true a warning message will be logged
+            when this happens. If false an error will be raised.
+        :param last_activity_time: last activity datetime.
+        """
+        if not previous_objects:
+            previous_objects = set()
+        list_keys = await self._list_keys_async(client=client, bucket_name=bucket_name, prefix=prefix)
+        current_objects = set(list_keys)
+        current_num_objects = len(current_objects)
+        if current_num_objects > len(previous_objects):
+            # When new objects arrived, reset the inactivity_seconds
+            # and update previous_objects for the next poke.
+            self.log.info(
+                "New objects found at %s, resetting last_activity_time.",
+                os.path.join(bucket_name, prefix),
+            )
+            self.log.debug("New objects: %s", current_objects - previous_objects)
+            last_activity_time = datetime.now()
+            inactivity_seconds = 0
+            previous_objects = current_objects
+            return {
+                "status": "pending",
+                "previous_objects": previous_objects,
+                "last_activity_time": last_activity_time,
+                "inactivity_seconds": inactivity_seconds,
+            }
+
+        if len(previous_objects) - len(current_objects):
+            # During the last poke interval objects were deleted.
+            if allow_delete:
+                deleted_objects = previous_objects - current_objects
+                previous_objects = current_objects
+                last_activity_time = datetime.now()
+                self.log.info(
+                    "Objects were deleted during the last poke interval. Updating the "
+                    "file counter and resetting last_activity_time:\n%s",
+                    deleted_objects,
+                )
+                return {
+                    "status": "pending",
+                    "previous_objects": previous_objects,
+                    "last_activity_time": last_activity_time,
+                    "inactivity_seconds": inactivity_seconds,
+                }
+
+            return {
+                "status": "error",
+                "message": f"Illegal behavior: objects were deleted in "
+                f"{os.path.join(bucket_name, prefix)} between pokes.",
+            }
+
+        if last_activity_time:
+            inactivity_seconds = int((datetime.now() - last_activity_time).total_seconds())
+        else:
+            # Handles the first poke where last inactivity time is None.
+            last_activity_time = datetime.now()
+            inactivity_seconds = 0
+
+        if inactivity_seconds >= inactivity_period:
+            path = os.path.join(bucket_name, prefix)
+
+            if current_num_objects >= min_objects:
+                success_message = (
+                    f"SUCCESS: Sensor found {current_num_objects} objects at {path}. "
+                    f"Waited at least {inactivity_period} seconds, with no new objects uploaded."
+                )
+                self.log.info(success_message)
+                return {
+                    "status": "success",
+                    "message": success_message,
+                }
+
+            self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path)
+            return {
+                "status": "error",
+                "message": f"FAILURE: Inactivity Period passed, not enough objects found in {path}",
+            }
+        return {
+            "status": "pending",
+            "previous_objects": previous_objects,
+            "last_activity_time": last_activity_time,
+            "inactivity_seconds": inactivity_seconds,
+        }
+
     @provide_bucket_name
     def list_keys(
         self,
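# --- Illustrative sketch (not part of the patch above): how the status payload returned by
# --- is_keys_unchanged_async() is meant to be consumed. A "pending" result echoes back
# --- previous_objects / last_activity_time / inactivity_seconds so the caller (normally the
# --- S3KeysUnchangedTrigger) can feed them into the next poll; "success" and "error" are terminal.
# --- The fake poll function below is a hypothetical stand-in for the real hook method.
import asyncio
from datetime import datetime
from typing import Any


async def fake_is_keys_unchanged(**state: Any) -> dict[str, Any]:
    # Pretend nothing changed and report success once enough "inactivity" has accumulated.
    if state["inactivity_seconds"] >= state["inactivity_period"]:
        return {"status": "success", "message": "no new objects uploaded"}
    return {
        "status": "pending",
        "previous_objects": state["previous_objects"],
        "last_activity_time": state.get("last_activity_time") or datetime.now(),
        "inactivity_seconds": state["inactivity_seconds"] + 1,
    }


async def poll_until_done() -> None:
    state: dict[str, Any] = {
        "previous_objects": set(),
        "last_activity_time": None,
        "inactivity_seconds": 0,
        "inactivity_period": 3,
    }
    while True:
        result = await fake_is_keys_unchanged(**state)
        if result["status"] in ("success", "error"):
            print(result)
            break
        # Carry the returned state into the next poll, which is exactly what the trigger does.
        state.update(
            previous_objects=result["previous_objects"],
            last_activity_time=result["last_activity_time"],
            inactivity_seconds=result["inactivity_seconds"],
        )


if __name__ == "__main__":
    asyncio.run(poll_until_done())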
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 23b344fa51621..7192585afd58b 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -31,7 +31,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger
+from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger
 from airflow.sensors.base import BaseSensorOperator, poke_mode_only
 
 
@@ -222,6 +222,7 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
     :param allow_delete: Should this sensor consider objects being deleted
         between pokes valid behavior. If true a warning message will be logged
         when this happens. If false an error will be raised.
+    :param deferrable: Run sensor in the deferrable mode
     """
 
     template_fields: Sequence[str] = ("bucket_name", "prefix")
@@ -237,6 +238,7 @@ def __init__(
         min_objects: int = 1,
         previous_objects: set[str] | None = None,
         allow_delete: bool = True,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
 
@@ -251,6 +253,7 @@ def __init__(
         self.previous_objects = previous_objects or set()
         self.inactivity_seconds = 0
         self.allow_delete = allow_delete
+        self.deferrable = deferrable
         self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.last_activity_time: datetime | None = None
@@ -325,3 +328,36 @@ def is_keys_unchanged(self, current_objects: set[str]) -> bool:
 
     def poke(self, context: Context):
         return self.is_keys_unchanged(set(self.hook.list_keys(self.bucket_name, prefix=self.prefix)))
+
+    def execute(self, context: Context) -> None:
+        """Airflow runs this method on the worker and defers using the trigger if deferrable is True."""
+        if not self.deferrable:
+            super().execute(context)
+        else:
+            if not self.poke(context):
+                self.defer(
+                    timeout=timedelta(seconds=self.timeout),
+                    trigger=S3KeysUnchangedTrigger(
+                        bucket_name=self.bucket_name,
+                        prefix=self.prefix,
+                        inactivity_period=self.inactivity_period,
+                        min_objects=self.min_objects,
+                        previous_objects=self.previous_objects,
+                        inactivity_seconds=self.inactivity_seconds,
+                        allow_delete=self.allow_delete,
+                        aws_conn_id=self.aws_conn_id,
+                        verify=self.verify,
+                        last_activity_time=self.last_activity_time,
+                    ),
+                    method_name="execute_complete",
+                )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        return None
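# --- Illustrative sketch (not part of the patch above): using the new ``deferrable`` flag in a DAG.
# --- The dag_id, bucket name and prefix are placeholder values. With deferrable=True the sensor
# --- defers after the first poke and the S3KeysUnchangedTrigger continues polling on the triggerer,
# --- freeing the worker slot while waiting.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor

with DAG(
    dag_id="example_s3_keys_unchanged_deferrable",  # placeholder dag id
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    wait_for_stable_prefix = S3KeysUnchangedSensor(
        task_id="wait_for_stable_prefix",
        bucket_name="my-example-bucket",  # placeholder bucket
        prefix="incoming/",  # placeholder prefix
        inactivity_period=10 * 60,  # consider the prefix stable after 10 minutes without new keys
        min_objects=1,
        timeout=60 * 60 * 2,  # give up after two hours
        deferrable=True,  # hand the polling over to the triggerer
    )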
diff --git a/airflow/providers/amazon/aws/triggers/s3.py b/airflow/providers/amazon/aws/triggers/s3.py
index 2e89de81f2a5c..06432108eb35f 100644
--- a/airflow/providers/amazon/aws/triggers/s3.py
+++ b/airflow/providers/amazon/aws/triggers/s3.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import asyncio
+from datetime import datetime
 from functools import cached_property
 from typing import Any, AsyncIterator
 
@@ -97,3 +98,109 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
 
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
+
+
+class S3KeysUnchangedTrigger(BaseTrigger):
+    """
+    S3KeysUnchangedTrigger is fired as a deferred class with params to run the task in the triggerer.
+
+    :param bucket_name: Name of the S3 bucket.
+    :param prefix: The prefix being waited on. Relative path from bucket root level.
+    :param inactivity_period: The total seconds of inactivity to designate
+        keys unchanged. Note, this mechanism is not real time and
+        this operator may not return until a poke_interval after this period
+        has passed with no additional objects sensed.
+    :param min_objects: The minimum number of objects needed for keys unchanged
+        sensor to be considered valid.
+    :param inactivity_seconds: number of seconds of inactivity observed so far
+    :param previous_objects: The set of object ids found during the last poke.
+    :param allow_delete: Should this sensor consider objects being deleted
+        between pokes valid behavior. If true a warning message will be logged
+        when this happens. If false an error will be raised.
+    :param aws_conn_id: reference to the AWS connection
+    :param last_activity_time: last modified or last active time
+    :param verify: Whether or not to verify SSL certificates for S3 connection.
+        By default SSL certificates are verified.
+    :param hook_params: optional extra params passed to the hook
+    """
+
+    def __init__(
+        self,
+        bucket_name: str,
+        prefix: str,
+        inactivity_period: float = 60 * 60,
+        min_objects: int = 1,
+        inactivity_seconds: int = 0,
+        previous_objects: set[str] | None = None,
+        allow_delete: bool = True,
+        aws_conn_id: str = "aws_default",
+        last_activity_time: datetime | None = None,
+        verify: bool | str | None = None,
+        **hook_params: Any,
+    ):
+        super().__init__()
+        self.bucket_name = bucket_name
+        self.prefix = prefix
+        if inactivity_period < 0:
+            raise ValueError("inactivity_period must be non-negative")
+        if not previous_objects:
+            previous_objects = set()
+        self.inactivity_period = inactivity_period
+        self.min_objects = min_objects
+        self.previous_objects = previous_objects
+        self.inactivity_seconds = inactivity_seconds
+        self.allow_delete = allow_delete
+        self.aws_conn_id = aws_conn_id
+        self.last_activity_time = last_activity_time
+        self.verify = verify
+        self.polling_period_seconds = 0
+        self.hook_params = hook_params
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize S3KeysUnchangedTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger",
+            {
+                "bucket_name": self.bucket_name,
+                "prefix": self.prefix,
+                "inactivity_period": self.inactivity_period,
+                "min_objects": self.min_objects,
+                "previous_objects": self.previous_objects,
+                "inactivity_seconds": self.inactivity_seconds,
+                "allow_delete": self.allow_delete,
+                "aws_conn_id": self.aws_conn_id,
+                "last_activity_time": self.last_activity_time,
+                "hook_params": self.hook_params,
+                "verify": self.verify,
+                "polling_period_seconds": self.polling_period_seconds,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> S3Hook:
+        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make an asynchronous connection using S3Hook."""
+        try:
+            async with self.hook.async_conn as client:
+                while True:
+                    result = await self.hook.is_keys_unchanged_async(
+                        client=client,
+                        bucket_name=self.bucket_name,
+                        prefix=self.prefix,
+                        inactivity_period=self.inactivity_period,
+                        min_objects=self.min_objects,
+                        previous_objects=self.previous_objects,
+                        inactivity_seconds=self.inactivity_seconds,
+                        allow_delete=self.allow_delete,
+                        last_activity_time=self.last_activity_time,
+                    )
+                    if result.get("status") == "success" or result.get("status") == "error":
+                        yield TriggerEvent(result)
+                    elif result.get("status") == "pending":
+                        self.previous_objects = result.get("previous_objects", set())
+                        self.last_activity_time = result.get("last_activity_time")
+                        self.inactivity_seconds = result.get("inactivity_seconds", 0)
+                        await asyncio.sleep(self.polling_period_seconds)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
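# --- Illustrative sketch (not part of the patch above): what serialize() is used for. The triggerer
# --- persists the returned classpath and keyword arguments and re-instantiates the trigger from them
# --- in its own process, which is why the serialized fields need to be values Airflow can store.
# --- The bucket name and prefix are placeholders; import_string is Airflow's dotted-path loader.
from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger
from airflow.utils.module_loading import import_string

trigger = S3KeysUnchangedTrigger(bucket_name="my-example-bucket", prefix="incoming/")
classpath, kwargs = trigger.serialize()

# Rebuild an equivalent trigger roughly the way the triggerer does.
rebuilt = import_string(classpath)(**kwargs)
assert rebuilt.bucket_name == trigger.bucket_name
assert rebuilt.previous_objects == set()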
diff --git a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
index 2d8558963e504..72e90c50caa25 100644
--- a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
+++ b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst
@@ -285,6 +285,16 @@ as the state of the listed objects in the Amazon S3 bucket will be lost between
     :start-after: [START howto_sensor_s3_keys_unchanged]
     :end-before: [END howto_sensor_s3_keys_unchanged]
 
+You can also run this sensor in deferrable mode by setting the parameter ``deferrable`` to True.
+This leads to more efficient utilization of Airflow workers, as polling for the job status happens
+asynchronously on the triggerer. Note that this requires the triggerer to be available in your Airflow deployment.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_s3_keys_unchanged_deferrable]
+    :end-before: [END howto_sensor_s3_keys_unchanged_deferrable]
+
 Reference
 ---------
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 267ff665c6d05..072137dc5f710 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -692,6 +692,127 @@ async def test_s3__check_key_with_wild_card_async(
         )
         assert response is False
 
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
+    async def test_s3_key_hook_is_keys_unchanged_false_async(self, mock_list_keys, mock_client):
+        """
+        Test is_keys_unchanged_async returns a pending response when keys are unchanged
+        (or deleted with allow_delete=True) within the specified period.
+        """
+
+        mock_list_keys.return_value = ["test"]
+
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        response = await s3_hook_async.is_keys_unchanged_async(
+            client=mock_client.return_value,
+            bucket_name="test_bucket",
+            prefix="test",
+            inactivity_period=1,
+            min_objects=1,
+            previous_objects=set(),
+            inactivity_seconds=0,
+            allow_delete=True,
+            last_activity_time=None,
+        )
+
+        assert response.get("status") == "pending"
+
+        # test for the case when current_objects < previous_objects
+        mock_list_keys.return_value = []
+
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+        response = await s3_hook_async.is_keys_unchanged_async(
+            client=mock_client.return_value,
+            bucket_name="test_bucket",
+            prefix="test",
+            inactivity_period=1,
+            min_objects=1,
+            previous_objects=set("test"),
+            inactivity_seconds=0,
+            allow_delete=True,
+            last_activity_time=None,
+        )
+
+        assert response.get("status") == "pending"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
+    async def test_s3_key_hook_is_keys_unchanged_exception_async(self, mock_list_keys, mock_client):
+        """
+        Test is_keys_unchanged_async returns an error when objects were deleted between pokes
+        and allow_delete is False.
+        """
+        mock_list_keys.return_value = []
+
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+
+        response = await s3_hook_async.is_keys_unchanged_async(
+            client=mock_client.return_value,
+            bucket_name="test_bucket",
+            prefix="test",
+            inactivity_period=1,
+            min_objects=1,
+            previous_objects=set("test"),
+            inactivity_seconds=0,
+            allow_delete=False,
+            last_activity_time=None,
+        )
+
+        assert response == {
+            "status": "error",
+            "message": "Illegal behavior: objects were deleted in test_bucket/test between pokes.",
+        }
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
+    async def test_s3_key_hook_is_keys_unchanged_pending_async(self, mock_list_keys, mock_client):
+        """
+        Test is_keys_unchanged_async returns a pending response while the inactivity period
+        has not yet elapsed.
+        """
+        mock_list_keys.return_value = []
+
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+
+        response = await s3_hook_async.is_keys_unchanged_async(
+            client=mock_client.return_value,
+            bucket_name="test_bucket",
+            prefix="test",
+            inactivity_period=1,
+            min_objects=0,
+            previous_objects=set(),
+            inactivity_seconds=0,
+            allow_delete=False,
+            last_activity_time=None,
+        )
+
+        assert response.get("status") == "pending"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async")
+    async def test_s3_key_hook_is_keys_unchanged_inactivity_error_async(self, mock_list_keys, mock_client):
+        """
+        Test is_keys_unchanged_async returns an error when the inactivity period has passed
+        without enough objects being found.
+        """
+        mock_list_keys.return_value = []
+
+        s3_hook_async = S3Hook(client_type="S3", resource_type="S3")
+
+        response = await s3_hook_async.is_keys_unchanged_async(
+            client=mock_client.return_value,
+            bucket_name="test_bucket",
+            prefix="test",
+            inactivity_period=0,
+            min_objects=5,
+            previous_objects=set(),
+            inactivity_seconds=5,
+            allow_delete=False,
+            last_activity_time=None,
+        )
+
+        assert response == {
+            "status": "error",
+            "message": "FAILURE: Inactivity Period passed, not enough objects found in test_bucket/test",
+        }
+
     def test_load_bytes(self, s3_bucket):
         hook = S3Hook()
         hook.load_bytes(b"Content", "my_key", s3_bucket)
+ """ + mock_list_keys.return_value = [] + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + + response = await s3_hook_async.is_keys_unchanged_async( + client=mock_client.return_value, + bucket_name="test_bucket", + prefix="test", + inactivity_period=1, + min_objects=0, + previous_objects=set(), + inactivity_seconds=0, + allow_delete=False, + last_activity_time=None, + ) + + assert response.get("status") == "pending" + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook._list_keys_async") + async def test_s3_key_hook_is_keys_unchanged_inactivity_error_async(self, mock_list_keys, mock_client): + """ + Test is_key_unchanged gives AirflowException. + """ + mock_list_keys.return_value = [] + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + + response = await s3_hook_async.is_keys_unchanged_async( + client=mock_client.return_value, + bucket_name="test_bucket", + prefix="test", + inactivity_period=0, + min_objects=5, + previous_objects=set(), + inactivity_seconds=5, + allow_delete=False, + last_activity_time=None, + ) + + assert response == { + "status": "error", + "message": "FAILURE: Inactivity Period passed, not enough objects found in test_bucket/test", + } + def test_load_bytes(self, s3_bucket): hook = S3Hook() hook.load_bytes(b"Content", "my_key", s3_bucket) diff --git a/tests/providers/amazon/aws/triggers/test_s3.py b/tests/providers/amazon/aws/triggers/test_s3.py index 08bba39b08a08..a73223f15ce89 100644 --- a/tests/providers/amazon/aws/triggers/test_s3.py +++ b/tests/providers/amazon/aws/triggers/test_s3.py @@ -17,13 +17,13 @@ from __future__ import annotations import asyncio +from datetime import datetime from unittest import mock as async_mock import pytest -from airflow.providers.amazon.aws.triggers.s3 import ( - S3KeyTrigger, -) +from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger +from airflow.triggers.base import TriggerEvent class TestS3KeyTrigger: @@ -75,3 +75,82 @@ async def test_run_pending(self, mock_client, mock_check_key_async): assert task.done() is False asyncio.get_event_loop().stop() + + +class TestS3KeysUnchangedTrigger: + def test_serialization(self): + """ + Asserts that the S3KeysUnchangedTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = S3KeysUnchangedTrigger( + bucket_name="test_bucket", + prefix="test", + inactivity_period=1, + min_objects=1, + inactivity_seconds=0, + previous_objects=None, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger" + assert kwargs == { + "bucket_name": "test_bucket", + "prefix": "test", + "inactivity_period": 1, + "min_objects": 1, + "inactivity_seconds": 0, + "previous_objects": set(), + "allow_delete": True, + "aws_conn_id": "aws_default", + "last_activity_time": None, + "hook_params": {}, + "verify": None, + "polling_period_seconds": 0, + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_run_wait(self, mock_client): + """Test if the task is run in trigger successfully.""" + mock_client.return_value.check_key.return_value = True + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + with mock_client: + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + assert task.done() is True + asyncio.get_event_loop().stop() + + def test_run_raise_value_error(self): + """ + Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. + """ + with pytest.raises(ValueError): + S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test", inactivity_period=-100) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.is_keys_unchanged_async") + async def test_run_success(self, mock_is_keys_unchanged, mock_client): + """ + Test if the task is run in triggerer successfully. + """ + mock_is_keys_unchanged.return_value = {"status": "success"} + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success"}) == actual + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.is_keys_unchanged_async") + async def test_run_pending(self, mock_is_keys_unchanged, mock_client): + """Test if the task is run in triggerer successfully.""" + mock_is_keys_unchanged.return_value = {"status": "pending", "last_activity_time": datetime.now()} + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() diff --git a/tests/system/providers/amazon/aws/example_s3.py b/tests/system/providers/amazon/aws/example_s3.py index 87806a0bf0a5a..484b47fba068e 100644 --- a/tests/system/providers/amazon/aws/example_s3.py +++ b/tests/system/providers/amazon/aws/example_s3.py @@ -244,10 +244,20 @@ def check_fn(files: list) -> bool: task_id="sensor_keys_unchanged", bucket_name=bucket_name_2, prefix=PREFIX, - inactivity_period=10, # inactivity_period in seconds + inactivity_period=10, ) # [END howto_sensor_s3_keys_unchanged] + # [START howto_sensor_s3_keys_unchanged_deferrable] + sensor_keys_unchanged = S3KeysUnchangedSensor( + task_id="sensor_keys_unchanged", + bucket_name=bucket_name_2, + prefix=PREFIX, + inactivity_period=10, # inactivity_period in seconds + deferrable=True, + ) + # [END howto_sensor_s3_keys_unchanged_deferrable] + # 
     delete_objects = S3DeleteObjectsOperator(
         task_id="delete_objects",