diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 5b8dd55467ab7..0489fa63128ad 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -18,6 +18,7 @@ """Interact with AWS S3, using the boto3 library.""" from __future__ import annotations +import asyncio import fnmatch import gzip as gz import io @@ -38,6 +39,13 @@ from urllib.parse import urlsplit from uuid import uuid4 +if TYPE_CHECKING: + try: + from aiobotocore.client import AioBaseClient + except ImportError: + pass + +from asgiref.sync import sync_to_async from boto3.s3.transfer import TransferConfig from botocore.exceptions import ClientError @@ -88,6 +96,29 @@ def wrapper(*args, **kwargs) -> T: return cast(T, wrapper) +def provide_bucket_name_async(func: T) -> T: + """ + Function decorator that provides a bucket name taken from the connection + in case no bucket name has been passed to the function. + """ + function_signature = signature(func) + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + bound_args = function_signature.bind(*args, **kwargs) + + if "bucket_name" not in bound_args.arguments: + self = args[0] + if self.aws_conn_id: + connection = await sync_to_async(self.get_connection)(self.aws_conn_id) + if connection.schema: + bound_args.arguments["bucket_name"] = connection.schema + + return await func(*bound_args.args, **bound_args.kwargs) + + return cast(T, wrapper) + + def unify_bucket_name_and_key(func: T) -> T: """ Function decorator that unifies bucket name and key taken from the key @@ -228,7 +259,6 @@ def get_s3_bucket_key( f"If `{bucket_param_name}` is provided, {key_param_name} should be a relative path " "from root level, rather than a full s3:// url" ) - return bucket, key @provide_bucket_name @@ -363,6 +393,226 @@ def list_prefixes( return prefixes + @provide_bucket_name_async + @unify_bucket_name_and_key + async def get_head_object_async( + self, client: AioBaseClient, key: str, bucket_name: str | None = None + ) -> dict[str, Any] | None: + """ + Retrieves metadata of an object. + + :param client: aiobotocore client + :param bucket_name: Name of the bucket in which the file is stored + :param key: S3 key that will point to the file + """ + head_object_val: dict[str, Any] | None = None + try: + head_object_val = await client.head_object(Bucket=bucket_name, Key=key) + return head_object_val + except ClientError as e: + if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: + return head_object_val + else: + raise e + + async def list_prefixes_async( + self, + client: AioBaseClient, + bucket_name: str | None = None, + prefix: str | None = None, + delimiter: str | None = None, + page_size: int | None = None, + max_items: int | None = None, + ) -> list[Any]: + """ + Lists prefixes in a bucket under prefix. + + :param client: ClientCreatorContext + :param bucket_name: the name of the bucket + :param prefix: a key prefix + :param delimiter: the delimiter marks key hierarchy. 
+ :param page_size: pagination size + :param max_items: maximum items to return + :return: a list of matched prefixes + """ + prefix = prefix or "" + delimiter = delimiter or "" + config = { + "PageSize": page_size, + "MaxItems": max_items, + } + + paginator = client.get_paginator("list_objects_v2") + response = paginator.paginate( + Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config + ) + + prefixes = [] + async for page in response: + if "CommonPrefixes" in page: + for common_prefix in page["CommonPrefixes"]: + prefixes.append(common_prefix["Prefix"]) + + return prefixes + + @provide_bucket_name_async + async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]: + """ + Gets a list of files that a key matching a wildcard expression exists in a bucket asynchronously. + + :param client: aiobotocore client + :param bucket_name: the name of the bucket + :param key: the path to the key + """ + prefix = re.split(r"[\[\*\?]", key, 1)[0] + delimiter = "" + paginator = client.get_paginator("list_objects_v2") + response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter) + files = [] + async for page in response: + if "Contents" in page: + files += page["Contents"] + return files + + async def _check_key_async( + self, + client: AioBaseClient, + bucket_val: str, + wildcard_match: bool, + key: str, + ) -> bool: + """ + Function to check if wildcard_match is True get list of files that a key matching a wildcard + expression exists in a bucket asynchronously and return the boolean value. If wildcard_match + is False get the head object from the bucket and return the boolean value. + + :param client: aiobotocore client + :param bucket_val: the name of the bucket + :param key: S3 keys that will point to the file + :param wildcard_match: the path to the key + """ + bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key") + if wildcard_match: + keys = await self.get_file_metadata_async(client, bucket_name, key) + key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)] + if len(key_matches) == 0: + return False + else: + obj = await self.get_head_object_async(client, key, bucket_name) + if obj is None: + return False + + return True + + async def check_key_async( + self, + client: AioBaseClient, + bucket: str, + bucket_keys: str | list[str], + wildcard_match: bool, + ) -> bool: + """ + Checks for all keys in bucket and returns boolean value. + + :param client: aiobotocore client + :param bucket: the name of the bucket + :param bucket_keys: S3 keys that will point to the file + :param wildcard_match: the path to the key + """ + if isinstance(bucket_keys, list): + return all( + await asyncio.gather( + *(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys) + ) + ) + return await self._check_key_async(client, bucket, wildcard_match, bucket_keys) + + async def check_for_prefix_async( + self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None + ) -> bool: + """ + Checks that a prefix exists in a bucket. + + :param bucket_name: the name of the bucket + :param prefix: a key prefix + :param delimiter: the delimiter marks key hierarchy. + :return: False if the prefix does not exist in the bucket and True if it does. 
+ """ + prefix = prefix + delimiter if prefix[-1] != delimiter else prefix + prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1) + previous_level = prefix_split[0] + plist = await self.list_prefixes_async(client, bucket_name, previous_level, delimiter) + return prefix in plist + + async def _check_for_prefix_async( + self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None + ) -> bool: + return await self.check_for_prefix_async( + client, prefix=prefix, delimiter=delimiter, bucket_name=bucket_name + ) + + async def get_files_async( + self, + client: AioBaseClient, + bucket: str, + bucket_keys: str | list[str], + wildcard_match: bool, + delimiter: str | None = "/", + ) -> list[Any]: + """Gets a list of files in the bucket.""" + keys: list[Any] = [] + for key in bucket_keys: + prefix = key + if wildcard_match: + prefix = re.split(r"[\[\*\?]", key, 1)[0] + + paginator = client.get_paginator("list_objects_v2") + response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter) + async for page in response: + if "Contents" in page: + _temp = [k for k in page["Contents"] if isinstance(k.get("Size", None), (int, float))] + keys = keys + _temp + return keys + + @staticmethod + async def _list_keys_async( + client: AioBaseClient, + bucket_name: str | None = None, + prefix: str | None = None, + delimiter: str | None = None, + page_size: int | None = None, + max_items: int | None = None, + ) -> list[str]: + """ + Lists keys in a bucket under prefix and not containing delimiter. + + :param bucket_name: the name of the bucket + :param prefix: a key prefix + :param delimiter: the delimiter marks key hierarchy. + :param page_size: pagination size + :param max_items: maximum items to return + :return: a list of matched keys + """ + prefix = prefix or "" + delimiter = delimiter or "" + config = { + "PageSize": page_size, + "MaxItems": max_items, + } + + paginator = client.get_paginator("list_objects_v2") + response = paginator.paginate( + Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config + ) + + keys = [] + async for page in response: + if "Contents" in page: + for k in page["Contents"]: + keys.append(k["Key"]) + + return keys + def _list_key_object_filter( self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None ) -> list: diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 7e46b6c91a8f4..9eb1ab75d524d 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -20,9 +20,9 @@ import fnmatch import os import re -from datetime import datetime +from datetime import datetime, timedelta from functools import cached_property -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence, cast from deprecated import deprecated @@ -31,6 +31,7 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger from airflow.sensors.base import BaseSensorOperator, poke_mode_only @@ -48,7 +49,7 @@ class S3KeySensor(BaseSensorOperator): or relative path from root level. When it's specified as a full s3:// url, please leave bucket_name as `None` :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` - is not provided as a full s3:// url. 
When specified, all the keys passed to ``bucket_key`` + is not provided as a full ``s3://`` url. When specified, all the keys passed to ``bucket_key`` refers to this bucket :param wildcard_match: whether the bucket_key should be interpreted as a Unix wildcard pattern @@ -61,8 +62,9 @@ class S3KeySensor(BaseSensorOperator): def check_fn(files: List) -> bool: return any(f.get('Size', 0) > 1048576 for f in files) :param aws_conn_id: a reference to the s3 connection - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. + :param deferrable: Run operator in the deferrable mode + :param verify: Whether to verify SSL certificates for S3 connection. + By default, SSL certificates are verified. You can provide the following values: - ``False``: do not validate SSL certificates. SSL will still be used @@ -84,6 +86,7 @@ def __init__( check_fn: Callable[..., bool] | None = None, aws_conn_id: str = "aws_default", verify: str | bool | None = None, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -93,6 +96,7 @@ def __init__( self.check_fn = check_fn self.aws_conn_id = aws_conn_id self.verify = verify + self.deferrable = deferrable def _check_key(self, key): bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key") @@ -131,6 +135,47 @@ def poke(self, context: Context): else: return all(self._check_key(key) for key in self.bucket_key) + def execute(self, context: Context) -> None: + """Airflow runs this method on the worker and defers using the trigger.""" + if not self.deferrable: + super().execute(context) + else: + if not self.poke(context=context): + self._defer() + + def _defer(self) -> None: + """Check for a keys in s3 and defers using the triggerer.""" + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=S3KeyTrigger( + bucket_name=cast(str, self.bucket_name), + bucket_key=self.bucket_key, + wildcard_match=self.wildcard_match, + aws_conn_id=self.aws_conn_id, + verify=self.verify, + poke_interval=self.poke_interval, + should_check_fn=True if self.check_fn else False, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> bool | None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "running": + found_keys = self.check_fn(event["files"]) # type: ignore[misc] + if found_keys: + return None + else: + self._defer() + + if event["status"] == "error": + raise AirflowException(event["message"]) + return None + @deprecated(reason="use `hook` property instead.") def get_hook(self) -> S3Hook: """Create and return an S3Hook.""" diff --git a/airflow/providers/amazon/aws/triggers/s3.py b/airflow/providers/amazon/aws/triggers/s3.py new file mode 100644 index 0000000000000..2e89de81f2a5c --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/s3.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from functools import cached_property +from typing import Any, AsyncIterator + +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class S3KeyTrigger(BaseTrigger): + """ + S3KeyTrigger is fired as deferred class with params to run the task in trigger worker. + + :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` + is not provided as a full s3:// url. + :param bucket_key: The key being waited on. Supports full s3:// style url + or relative path from root level. When it's specified as a full s3:// + url, please leave bucket_name as `None`. + :param wildcard_match: whether the bucket_key should be interpreted as a + Unix wildcard pattern + :param aws_conn_id: reference to the s3 connection + :param hook_params: params for hook its optional + """ + + def __init__( + self, + bucket_name: str, + bucket_key: str | list[str], + wildcard_match: bool = False, + aws_conn_id: str = "aws_default", + poke_interval: float = 5.0, + should_check_fn: bool = False, + **hook_params: Any, + ): + super().__init__() + self.bucket_name = bucket_name + self.bucket_key = bucket_key + self.wildcard_match = wildcard_match + self.aws_conn_id = aws_conn_id + self.hook_params = hook_params + self.poke_interval = poke_interval + self.should_check_fn = should_check_fn + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize S3KeyTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger", + { + "bucket_name": self.bucket_name, + "bucket_key": self.bucket_key, + "wildcard_match": self.wildcard_match, + "aws_conn_id": self.aws_conn_id, + "hook_params": self.hook_params, + "poke_interval": self.poke_interval, + "should_check_fn": self.should_check_fn, + }, + ) + + @cached_property + def hook(self) -> S3Hook: + return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Make an asynchronous connection using S3HookAsync.""" + try: + async with self.hook.async_conn as client: + while True: + if await self.hook.check_key_async( + client, self.bucket_name, self.bucket_key, self.wildcard_match + ): + if self.should_check_fn: + s3_objects = await self.hook.get_files_async( + client, self.bucket_name, self.bucket_key, self.wildcard_match + ) + await asyncio.sleep(self.poke_interval) + yield TriggerEvent({"status": "running", "files": s3_objects}) + else: + yield TriggerEvent({"status": "success"}) + await asyncio.sleep(self.poke_interval) + + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 9725a55bbaa15..05924eebc9742 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -73,9 +73,9 @@ dependencies: - mypy-boto3-rds>=1.24.0 - mypy-boto3-redshift-data>=1.24.0 - mypy-boto3-appflow>=1.24.0 + - asgiref - mypy-boto3-s3>=1.24.0 - integrations: - integration-name: 
Amazon Athena external-doc-url: https://aws.amazon.com/athena/ @@ -522,6 +522,9 @@ triggers: python-modules: - airflow.providers.amazon.aws.triggers.glue - airflow.providers.amazon.aws.triggers.glue_crawler + - integration-name: Amazon Simple Storage Service (S3) + python-modules: + - airflow.providers.amazon.aws.triggers.s3 - integration-name: Amazon EMR python-modules: - airflow.providers.amazon.aws.triggers.emr diff --git a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst index ef9d0864490f9..2d8558963e504 100644 --- a/docs/apache-airflow-providers-amazon/operators/s3/s3.rst +++ b/docs/apache-airflow-providers-amazon/operators/s3/s3.rst @@ -248,6 +248,26 @@ multiple files can match one key. The list of matched S3 object attributes conta :start-after: [START howto_sensor_s3_key_function] :end-before: [END howto_sensor_s3_key_function] +You can also run this operator in deferrable mode by setting the parameter ``deferrable`` to True. +This will lead to efficient utilization of Airflow workers as polling for job status happens on +the triggerer asynchronously. Note that this will need triggerer to be available on your Airflow deployment. + +To check one file: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_s3_key_single_key_deferrable] + :end-before: [END howto_sensor_s3_key_single_key_deferrable] + +To check multiple files: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_s3.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_s3_key_multiple_keys_deferrable] + :end-before: [END howto_sensor_s3_key_multiple_keys_deferrable] + .. _howto/sensor:S3KeysUnchangedSensor: Wait on Amazon S3 prefix changes diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 57d35a8500c56..2fcc456d6cb2c 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -22,6 +22,7 @@ "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.4.0", "asgiref", + "asgiref", "boto3>=1.24.0", "jsonpath_ng>=1.5.3", "mypy-boto3-appflow>=1.24.0", diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index cd23823dcbb80..001584ccbacc3 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -22,8 +22,9 @@ import os import re import tempfile +import unittest from pathlib import Path -from unittest import mock +from unittest import mock, mock as async_mock from unittest.mock import MagicMock, Mock, patch import boto3 @@ -33,7 +34,11 @@ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name, unify_bucket_name_and_key +from airflow.providers.amazon.aws.hooks.s3 import ( + S3Hook, + provide_bucket_name, + unify_bucket_name_and_key, +) from airflow.utils.timezone import datetime @@ -387,6 +392,294 @@ def test_load_string_acl(self, s3_bucket): response["Grants"][0]["Permission"] == "FULL_CONTROL" ) + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + @pytest.mark.asyncio + async def test_s3_key_hook_get_file_metadata_async(self, mock_client): + """ + Test check_wildcard_key for a valid response + :return: + """ + test_resp_iter = [ + { + "Contents": [ + {"Key": "test_key", "ETag": "etag1", "LastModified": 
datetime(2020, 8, 14, 17, 19, 34)}, + {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + ] + } + ] + mock_paginator = mock.Mock() + mock_paginate = mock.MagicMock() + mock_paginate.__aiter__.return_value = test_resp_iter + mock_paginator.paginate.return_value = mock_paginate + + s3_hook_async = S3Hook(client_type="S3") + mock_client.get_paginator = mock.Mock(return_value=mock_paginator) + task = await s3_hook_async.get_file_metadata_async(mock_client, "test_bucket", "test*") + assert task == [ + {"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + ] + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_s3_key_hook_get_head_object_with_error_async(self, mock_client): + """ + Test for 404 error if key not found and assert based on response. + :return: + """ + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + + mock_client.head_object.side_effect = ClientError( + { + "Error": { + "Code": "SomeServiceException", + "Message": "Details/context around the exception or error", + }, + "ResponseMetadata": { + "RequestId": "1234567890ABCDEF", + "HostId": "host ID data will appear here as a hash", + "HTTPStatusCode": 404, + "HTTPHeaders": {"header metadata key/values will appear here"}, + "RetryAttempts": 0, + }, + }, + operation_name="s3", + ) + response = await s3_hook_async.get_head_object_async( + mock_client, "s3://test_bucket/file", "test_bucket" + ) + assert response is None + + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + @pytest.mark.asyncio + @unittest.expectedFailure + async def test_s3_key_hook_get_head_object_raise_exception_async(self, mock_client): + """ + Test for 500 error if key not found and assert based on response. 
+ :return: + """ + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + + mock_client.head_object.side_effect = ClientError( + { + "Error": { + "Code": "SomeServiceException", + "Message": "Details/context around the exception or error", + }, + "ResponseMetadata": { + "RequestId": "1234567890ABCDEF", + "HostId": "host ID data will appear here as a hash", + "HTTPStatusCode": 500, + "HTTPHeaders": {"header metadata key/values will appear here"}, + "RetryAttempts": 0, + }, + }, + operation_name="s3", + ) + with pytest.raises(ClientError) as err: + response = await s3_hook_async.get_head_object_async( + mock_client, "s3://test_bucket/file", "test_bucket" + ) + assert isinstance(response, err) + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_s3_key_hook_get_files_without_wildcard_async(self, mock_client): + """ + Test get_files for a valid response + :return: + """ + test_resp_iter = [ + { + "Contents": [ + {"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + ] + } + ] + mock_paginator = mock.Mock() + mock_paginate = mock.MagicMock() + mock_paginate.__aiter__.return_value = test_resp_iter + mock_paginator.paginate.return_value = mock_paginate + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + mock_client.get_paginator = mock.Mock(return_value=mock_paginator) + response = await s3_hook_async.get_files_async(mock_client, "test_bucket", "test.txt", False) + assert response == [] + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_s3_key_hook_get_files_with_wildcard_async(self, mock_client): + """ + Test get_files for a valid response + :return: + """ + test_resp_iter = [ + { + "Contents": [ + {"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + ] + } + ] + mock_paginator = mock.Mock() + mock_paginate = mock.MagicMock() + mock_paginate.__aiter__.return_value = test_resp_iter + mock_paginator.paginate.return_value = mock_paginate + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + mock_client.get_paginator = mock.Mock(return_value=mock_paginator) + response = await s3_hook_async.get_files_async(mock_client, "test_bucket", "test.txt", True) + assert response == [] + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_s3_key_hook_list_keys_async(self, mock_client): + """ + Test _list_keys for a valid response + :return: + """ + test_resp_iter = [ + { + "Contents": [ + {"Key": "test_key", "ETag": "etag1", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + {"Key": "test_key2", "ETag": "etag2", "LastModified": datetime(2020, 8, 14, 17, 19, 34)}, + ] + } + ] + mock_paginator = mock.Mock() + mock_paginate = mock.MagicMock() + mock_paginate.__aiter__.return_value = test_resp_iter + mock_paginator.paginate.return_value = mock_paginate + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + mock_client.get_paginator = mock.Mock(return_value=mock_paginator) + response = await s3_hook_async._list_keys_async(mock_client, "test_bucket", "test*") + assert response == ["test_key", "test_key2"] + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_first_prefix, test_second_prefix", + [ + 
("async-prefix1/", "async-prefix2/"), + ], + ) + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") + async def test_s3_prefix_sensor_hook_list_prefixes_async( + self, mock_client, test_first_prefix, test_second_prefix + ): + """ + Test list_prefixes whether it returns a valid response + """ + test_resp_iter = [{"CommonPrefixes": [{"Prefix": test_first_prefix}, {"Prefix": test_second_prefix}]}] + mock_paginator = mock.Mock() + mock_paginate = mock.MagicMock() + mock_paginate.__aiter__.return_value = test_resp_iter + mock_paginator.paginate.return_value = mock_paginate + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + mock_client.get_paginator = mock.Mock(return_value=mock_paginator) + + actual_output = await s3_hook_async.list_prefixes_async(mock_client, "test_bucket", "test") + expected_output = [test_first_prefix, test_second_prefix] + assert expected_output == actual_output + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_prefix, mock_bucket", + [ + ("async-prefix1", "test_bucket"), + ], + ) + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.list_prefixes_async") + async def test_s3_prefix_sensor_hook_check_for_prefix_async( + self, mock_list_prefixes, mock_client, mock_prefix, mock_bucket + ): + """ + Test that _check_for_prefix method returns True when valid prefix is used and returns False + when invalid prefix is used + """ + mock_list_prefixes.return_value = ["async-prefix1/", "async-prefix2/"] + + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + + response = await s3_hook_async._check_for_prefix_async( + client=mock_client.return_value, prefix=mock_prefix, bucket_name=mock_bucket, delimiter="/" + ) + + assert response is True + + response = await s3_hook_async._check_for_prefix_async( + client=mock_client.return_value, + prefix="non-existing-prefix", + bucket_name=mock_bucket, + delimiter="/", + ) + + assert response is False + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") + async def test_s3__check_key_without_wild_card_async( + self, mock_client, mock_head_object, mock_get_bucket_key + ): + """Test _check_key function""" + mock_get_bucket_key.return_value = "test_bucket", "test.txt" + mock_head_object.return_value = {"ContentLength": 0} + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + response = await s3_hook_async._check_key_async( + mock_client.return_value, "test_bucket", False, "s3://test_bucket/file/test.txt" + ) + assert response is True + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") + async def test_s3__check_key_none_without_wild_card_async( + self, mock_client, mock_head_object, mock_get_bucket_key + ): + """Test _check_key function when get head object returns none""" + mock_get_bucket_key.return_value = "test_bucket", "test.txt" + mock_head_object.return_value = None + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + response = await s3_hook_async._check_key_async( + mock_client.return_value, "test_bucket", False, 
"s3://test_bucket/file/test.txt" + ) + assert response is False + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_s3_bucket_key") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_file_metadata_async") + @async_mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.async_conn") + async def test_s3__check_key_with_wild_card_async( + self, mock_client, mock_get_file_metadata, mock_get_bucket_key + ): + """Test _check_key function""" + mock_get_bucket_key.return_value = "test_bucket", "test" + mock_get_file_metadata.return_value = [ + { + "Key": "test_key", + "ETag": "etag1", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + { + "Key": "test_key2", + "ETag": "etag2", + "LastModified": datetime(2020, 8, 14, 17, 19, 34), + "Size": 0, + }, + ] + s3_hook_async = S3Hook(client_type="S3", resource_type="S3") + response = await s3_hook_async._check_key_async( + mock_client.return_value, "test_bucket", True, "test/example_s3_test_file.txt" + ) + assert response is False + def test_load_bytes(self, s3_bucket): hook = S3Hook() hook.load_bytes(b"Content", "my_key", s3_bucket) diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index f0832d3df9a89..7c6f435f53df6 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -225,3 +225,23 @@ def check_fn(files: list) -> bool: mock_head_object.return_value = {"ContentLength": 1} assert op.poke(None) is True + + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3KeySensor.poke", return_value=False) + def test_s3_key_sensor_execute_complete_success_with_keys(self, mock_poke): + """ + Asserts that a task is completed with success status and check function + """ + + def check_fn(files: list) -> bool: + return all(f.get("Size", 0) > 0 for f in files) + + sensor = S3KeySensor( + task_id="s3_key_sensor_async", + bucket_key="key", + bucket_name="bucket", + check_fn=check_fn, + deferrable=True, + ) + assert ( + sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None + ) diff --git a/tests/providers/amazon/aws/triggers/test_s3.py b/tests/providers/amazon/aws/triggers/test_s3.py new file mode 100644 index 0000000000000..08bba39b08a08 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_s3.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +from unittest import mock as async_mock + +import pytest + +from airflow.providers.amazon.aws.triggers.s3 import ( + S3KeyTrigger, +) + + +class TestS3KeyTrigger: + def test_serialization(self): + """ + Asserts that the TaskStateTrigger correctly serializes its arguments + and classpath. + """ + trigger = S3KeyTrigger( + bucket_key="s3://test_bucket/file", bucket_name="test_bucket", wildcard_match=True + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger" + assert kwargs == { + "bucket_name": "test_bucket", + "bucket_key": "s3://test_bucket/file", + "wildcard_match": True, + "aws_conn_id": "aws_default", + "hook_params": {}, + "poke_interval": 5.0, + "should_check_fn": False, + } + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_run_success(self, mock_client): + """ + Test if the task is run is in triggerr successfully. + """ + mock_client.return_value.check_key.return_value = True + trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket") + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + assert task.done() is True + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.async_conn") + async def test_run_pending(self, mock_client, mock_check_key_async): + """ + Test if the task is run is in trigger successfully and set check_key to return false. + """ + mock_check_key_async.return_value = False + trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket") + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + assert task.done() is False + asyncio.get_event_loop().stop() diff --git a/tests/system/providers/amazon/aws/example_s3.py b/tests/system/providers/amazon/aws/example_s3.py index a18e16b79ad61..87806a0bf0a5a 100644 --- a/tests/system/providers/amazon/aws/example_s3.py +++ b/tests/system/providers/amazon/aws/example_s3.py @@ -171,6 +171,37 @@ def check_fn(files: list) -> bool: ) # [END howto_sensor_s3_key_multiple_keys] + # [START howto_sensor_s3_key_single_key_deferrable] + # Check if a file exists + sensor_one_key_deferrable = S3KeySensor( + task_id="sensor_one_key_deferrable", + bucket_name=bucket_name, + bucket_key=key, + deferrable=True, + ) + # [END howto_sensor_s3_key_single_key_deferrable] + + # [START howto_sensor_s3_key_multiple_keys_deferrable] + # Check if both files exist + sensor_two_keys_deferrable = S3KeySensor( + task_id="sensor_two_keys_deferrable", + bucket_name=bucket_name, + bucket_key=[key, key_2], + deferrable=True, + ) + # [END howto_sensor_s3_key_multiple_keys_deferrable] + + # [START howto_sensor_s3_key_function] + # Check if a file exists and match a certain pattern defined in check_fn + sensor_key_with_function_deferrable = S3KeySensor( + task_id="sensor_key_with_function_deferrable", + bucket_name=bucket_name, + bucket_key=key, + check_fn=check_fn, + deferrable=True, + ) + # [END howto_sensor_s3_key_function] + # [START howto_sensor_s3_key_function] # Check if a file exists and match a certain pattern defined in check_fn sensor_key_with_function = S3KeySensor( @@ -256,6 +287,7 @@ def check_fn(files: list) -> bool: list_prefixes, list_keys, [sensor_one_key, sensor_two_keys, 
sensor_key_with_function], + [sensor_one_key_deferrable, sensor_two_keys_deferrable, sensor_key_with_function_deferrable], copy_object, file_transform, branching,
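
Taken together, the changes above move the S3 key wait off the worker and onto the triggerer: the sensor polls once, defers with an ``S3KeyTrigger``, and resumes in ``execute_complete`` when the trigger reports ``success``, ``running`` (with the matched objects for ``check_fn``), or ``error``. Below is a minimal sketch of a DAG opting into this behaviour; the DAG id, bucket name, and key are illustrative and not part of the patch, and an ``aws_default`` connection is assumed to exist.

```python
# Minimal sketch of using the deferrable S3KeySensor introduced by this change.
# The DAG id, bucket name, and key are illustrative, not taken from the patch.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def check_fn(files: list) -> bool:
    # Succeed only when every matched S3 object is non-empty.
    return all(f.get("Size", 0) > 0 for f in files)


with DAG(
    dag_id="example_s3_key_deferrable",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    # With deferrable=True the wait runs on the triggerer, so the worker slot
    # is released while the key does not yet exist.
    wait_for_key = S3KeySensor(
        task_id="wait_for_key",
        bucket_name="my-bucket",
        bucket_key="incoming/data.csv",
        wildcard_match=False,
        check_fn=check_fn,
        deferrable=True,
    )
```

When ``check_fn`` is supplied, the trigger yields a ``running`` event carrying the matched object metadata and the sensor re-defers until ``check_fn`` returns True; without it, the first ``success`` event completes the task.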