diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 5b8dd55467ab7..6c27f0236c27f 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -496,6 +496,26 @@ def get_file_metadata(
                 files += page["Contents"]
         return files
 
+    async def get_file_metadata_async(
+        self,
+        prefix: str,
+        bucket_name: str | None = None,
+        page_size: int | None = None,
+        max_items: int | None = None,
+    ) -> list:
+
+        config = {
+            "PageSize": page_size,
+            "MaxItems": max_items,
+        }
+        async with self.async_conn as client:
+            paginator = client.get_paginator("list_objects_v2")
+            files = []
+            async for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config):
+                if "Contents" in page:
+                    files += page["Contents"]
+            return files
+
     @unify_bucket_name_and_key
     @provide_bucket_name
     def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
@@ -517,6 +537,16 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
             else:
                 raise e
 
+    async def head_object_async(self, key: str, bucket_name: str | None = None) -> dict | None:
+        try:
+            async with self.async_conn as client:
+                return await client.head_object(Bucket=bucket_name, Key=key)
+        except ClientError as e:
+            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+                return None
+            else:
+                raise e
+
     @unify_bucket_name_and_key
     @provide_bucket_name
     def check_for_key(self, key: str, bucket_name: str | None = None) -> bool:
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 7e46b6c91a8f4..6f71484fbc55d 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -22,7 +22,7 @@
 import re
 from datetime import datetime
 from functools import cached_property
-from typing import TYPE_CHECKING, Callable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence
 
 from deprecated import deprecated
 
@@ -71,6 +71,7 @@ def check_fn(files: List) -> bool:
         - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
             You can specify this argument if you want to use a different
             CA cert bundle than the one used by botocore.
+    :param deferrable: If True, the sensor will run in deferrable mode.
     """
 
     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -84,6 +85,7 @@ def __init__(
         check_fn: Callable[..., bool] | None = None,
         aws_conn_id: str = "aws_default",
         verify: str | bool | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -93,54 +95,102 @@ def __init__(
         self.check_fn = check_fn
         self.aws_conn_id = aws_conn_id
         self.verify = verify
-
-    def _check_key(self, key):
-        bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
-        self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
-
-        """
-        Set variable `files` which contains a list of dict which contains only the size
-        If needed we might want to add other attributes later
-        Format: [{
-            'Size': int
-        }]
-        """
-        if self.wildcard_match:
-            prefix = re.split(r"[\[\*\?]", key, 1)[0]
-            keys = self.hook.get_file_metadata(prefix, bucket_name)
-            key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
-            if len(key_matches) == 0:
-                return False
-
-            # Reduce the set of metadata to size only
-            files = list(map(lambda f: {"Size": f["Size"]}, key_matches))
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger
+
+            self.defer(
+                trigger=S3KeyTrigger(
+                    bucket_name=self.bucket_name,
+                    bucket_key=self.bucket_key,
+                    wildcard_match=self.wildcard_match,
+                    aws_conn_id=self.aws_conn_id,
+                    verify=self.verify,
+                ),
+                method_name="execute_complete",
+            )
         else:
-            obj = self.hook.head_object(key, bucket_name)
-            if obj is None:
-                return False
-            files = [{"Size": obj["ContentLength"]}]
-
-        if self.check_fn is not None:
-            return self.check_fn(files)
-
-        return True
+            super().execute(context=context)
 
     def poke(self, context: Context):
         if isinstance(self.bucket_key, str):
-            return self._check_key(self.bucket_key)
+            self.bucket_keys = [self.bucket_key]
         else:
-            return all(self._check_key(key) for key in self.bucket_key)
+            self.bucket_keys = self.bucket_key
+        wildcard_keys = []
+        objs = []
+        bucket_key_names = []
+        for i in range(len(self.bucket_keys)):
+            bucket_key_names.append(
+                S3Hook.get_s3_bucket_key(self.bucket_name, self.bucket_keys[i], "bucket_name", "bucket_key")
+            )
+            bucket_name = bucket_key_names[i][0]
+            key = bucket_key_names[i][1]
+            self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
+            if self.wildcard_match:
+                prefix = re.split(r"[\[\*\?]", key, 1)[0]
+                wildcard_keys.append(self.hook.get_file_metadata(prefix, bucket_name))
+            else:
+                objs.append(self.hook.head_object(key, bucket_name))
+
+        results = process_files(
+            self.bucket_keys, self.wildcard_match, wildcard_keys, objs, self.check_fn, bucket_key_names
+        )[0]
+        return all(results)
 
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> S3Hook:
         """Create and return an S3Hook."""
         return self.hook
 
+    def execute_complete(self, context, event=None):
+        self.log.info("Inside execute complete")
+        if event["status"] != "success":
+            raise AirflowException(f"Error: {event}")
+        else:
+            results = []
+            self.log.info("Success: %s", event)
+            if self.check_fn is not None:
+                for files in event["files_list"]:
+                    results.append(self.check_fn(files))
+                return all(results)
+
     @cached_property
     def hook(self) -> S3Hook:
         return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
 
 
+def process_files(bucket_keys, wildcard_match, wildcard_keys, obj, check_fn, bucket_key_names):
+    results = []
+    files_list = []
+    for i in range(len(bucket_keys)):
+        key = bucket_key_names[i][1]
+        if wildcard_match:
+            key_matches = [k for k in wildcard_keys[i] if fnmatch.fnmatch(k["Key"], key)]
+            if len(key_matches) == 0:
+                results.append(False)
+                continue
+            # Reduce the set of metadata to size only
+            files_list.append(list(map(lambda f: {"Size": f["Size"]}, key_matches)))
+        else:
+            if obj[i] is None:
+                results.append(False)
+                continue
+
+            files_list.append([{"Size": obj[i]["ContentLength"]}])
+
+        if check_fn is not None:
+            for files in files_list:
+                results.append(check_fn(files))
+            continue
+
+        results.append(True)
+
+    return [results, files_list]
+
+
 @poke_mode_only
 class S3KeysUnchangedSensor(BaseSensorOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/s3.py b/airflow/providers/amazon/aws/triggers/s3.py
new file mode 100644
index 0000000000000..e8f9f0544f4a4
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/s3.py
@@ -0,0 +1,106 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import re
+from functools import cached_property
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.sensors.s3 import process_files
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class S3KeyTrigger(BaseTrigger):
+    """Trigger for S3KeySensor"""
+
+    def __init__(
+        self,
+        *,
+        bucket_key: str | list[str],
+        bucket_name: str | None = None,
+        wildcard_match: bool = False,
+        aws_conn_id: str = "aws_default",
+        verify: str | bool | None = None,
+        poll_interval: int = 60,
+    ):
+        self.bucket_name = bucket_name
+        self.bucket_key = bucket_key
+        self.wildcard_match = wildcard_match
+        self.aws_conn_id = aws_conn_id
+        self.verify = verify
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger",
+            {
+                "bucket_key": self.bucket_key,
+                "bucket_name": self.bucket_name,
+                "wildcard_match": self.wildcard_match,
+                "aws_conn_id": self.aws_conn_id,
+                "verify": self.verify,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> S3Hook:
+        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+
+    async def poke(self):
+        if isinstance(self.bucket_key, str):
+            self.bucket_keys = [self.bucket_key]
+        else:
+            self.bucket_keys = self.bucket_key
+
+        wildcard_keys = []
+        obj = []
+        bucket_key_names = []
+        for i in range(len(self.bucket_keys)):
+            bucket_key_names.append(
+                S3Hook.get_s3_bucket_key(self.bucket_name, self.bucket_keys[i], "bucket_name", "bucket_key")
+            )
+            bucket_name = bucket_key_names[i][0]
+            key = bucket_key_names[i][1]
+            self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
+            if self.wildcard_match:
+                prefix = re.split(r"[\[\*\?]", key, 1)[0]
+                wildcard_keys.append(await self.hook.get_file_metadata_async(prefix, bucket_name))
+            else:
+                obj.append(await self.hook.head_object_async(key, bucket_name))
+
+        response = process_files(
+            self.bucket_keys, self.wildcard_match, wildcard_keys, obj, None, bucket_key_names
+        )
+        return [all(response[0]), response[1]]
+
+    async def run(self):
+        while True:
+            response = await self.poke()
+            if response[0]:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": "S3KeyTrigger success",
+                        "files_list": response[1],
+                    }
+                )
+            else:
+                await asyncio.sleep(int(self.poll_interval))
diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py
index f0832d3df9a89..2210b94a37055 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_key.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -21,7 +21,7 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.variable import Variable
 from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
@@ -225,3 +225,29 @@ def check_fn(files: list) -> bool:
 
         mock_head_object.return_value = {"ContentLength": 1}
         assert op.poke(None) is True
+
+    @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object")
+    def test_poke_with_check_function_with_multiple_files(self, mock_head_object):
+        def check_fn(files: list) -> bool:
+            return all(f.get("Size", 0) > 0 for f in files)
+
+        op = S3KeySensor(
+            task_id="s3_key_sensor",
+            bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"],
+            check_fn=check_fn,
+        )
+
+        mock_head_object.side_effect = [{"ContentLength": 0}, {"ContentLength": 0}]
+        assert op.poke(None) is False
+
+        mock_head_object.side_effect = [{"ContentLength": 0}, {"ContentLength": 1}]
+        assert op.poke(None) is False
+
+        mock_head_object.side_effect = [{"ContentLength": 1}, {"ContentLength": 1}]
+        assert op.poke(None) is True
+
+    def test_deferrable_mode(self):
+        op = S3KeySensor(task_id="s3_key_sensor", bucket_key="s3://test_bucket/file", deferrable=True)
+
+        with pytest.raises(TaskDeferred):
+            op.execute(None)
diff --git a/tests/providers/amazon/aws/triggers/test_s3_key_trigger.py b/tests/providers/amazon/aws/triggers/test_s3_key_trigger.py
new file mode 100644
index 0000000000000..11a7085bf4146
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_s3_key_trigger.py
@@ -0,0 +1,358 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
+from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_KEY = ["test-key", "test-key"]
+TEST_BUCKET = "test_bucket"
+TEST_CONN_ID = "test_conn"
+TEST_VERIFY = True
+TEST_WILDCARD_MATCH = False
+TEST_POLL_INTERVAL = 100
+
+
+class TestS3KeySensor:
+    def test_s3_key_trigger_serialization(self):
+        s3_key_trigger = S3KeyTrigger(
+            bucket_key=TEST_KEY[0],
+            bucket_name=TEST_BUCKET,
+            wildcard_match=TEST_WILDCARD_MATCH,
+            aws_conn_id=TEST_CONN_ID,
+            verify=TEST_VERIFY,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        class_path, args = s3_key_trigger.serialize()
+
+        assert class_path == "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger"
+
+        assert args["bucket_key"] == TEST_KEY[0]
+        assert args["bucket_name"] == TEST_BUCKET
+        assert args["aws_conn_id"] == TEST_CONN_ID
+        assert args["wildcard_match"] is TEST_WILDCARD_MATCH
+        assert args["verify"] is TEST_VERIFY
+        assert isinstance(args["poll_interval"], int)
+        assert args["poll_interval"] == TEST_POLL_INTERVAL
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_s3_key_trigger_run(self, mock):
+        mock.return_value = {
+            "ContentLength": 123,
+        }
+
+        s3_key_trigger = S3KeyTrigger(
+            bucket_key=TEST_KEY[0],
+            bucket_name=TEST_BUCKET,
+            wildcard_match=TEST_WILDCARD_MATCH,
+            aws_conn_id=TEST_CONN_ID,
+            verify=TEST_VERIFY,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        generator = s3_key_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {"status": "success", "message": "S3KeyTrigger success", "files_list": [[{"Size": 123}]]}
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "get_file_metadata_async")
+    async def test_s3_key_sensor_trigger_run_with_wildcard(self, mock_get_file_metadata_async):
+        mock_get_file_metadata_async.return_value = [
+            {
+                "Key": "test-key",
+                "Size": 11,
+            },
+        ]
+
+        s3_key_trigger = S3KeyTrigger(
+            bucket_key=TEST_KEY[0],
+            bucket_name=TEST_BUCKET,
+            wildcard_match=True,
+            aws_conn_id=TEST_CONN_ID,
+            verify=TEST_VERIFY,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        generator = s3_key_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {"status": "success", "message": "S3KeyTrigger success", "files_list": [[{"Size": 11}]]}
+        )
+
+    @pytest.mark.asyncio
+    async def test_deferrable_poke_bucket_name_none_and_bucket_key_as_relative_path(self):
+        """
+        Test if exception is raised when bucket_name is None
+        and bucket_key is provided as relative path rather than s3:// url.
+        :return:
+        """
+        op = S3KeyTrigger(bucket_key="file_in_bucket")
+        with pytest.raises(AirflowException):
+            await op.poke()
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_deferrable_poke_bucket_name_none_and_bucket_key_is_list_and_contain_relative_path(
+        self, mock_head_object_async
+    ):
+        """
+        Test if exception is raised when bucket_name is None
+        and bucket_key is provided with one of the two keys as relative path rather than s3:// url.
+        :return:
+        """
+        mock_head_object_async.return_value = {"ContentLength": 0}
+        op = S3KeyTrigger(bucket_key=["s3://test_bucket/file", "file_in_bucket"])
+        with pytest.raises(AirflowException):
+            await op.poke()
+
+    @pytest.mark.asyncio
+    async def test_deferrable_poke_bucket_name_provided_and_bucket_key_is_s3_url(self):
+        """
+        Test if exception is raised when bucket_name is provided
+        while bucket_key is provided as a full s3:// url.
+        :return:
+        """
+        op = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket")
+        with pytest.raises(TypeError):
+            await op.poke()
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_deferrable_poke_bucket_name_provided_and_bucket_key_is_list_and_contains_s3_url(
+        self, mock_head_object_async
+    ):
+        """
+        Test if exception is raised when bucket_name is provided
+        while bucket_key contains a full s3:// url.
+        :return:
+        """
+        mock_head_object_async.return_value = {"ContentLength": 0}
+        op = S3KeyTrigger(
+            bucket_key=["test_bucket", "s3://test_bucket/file"],
+            bucket_name="test_bucket",
+        )
+        with pytest.raises(TypeError):
+            await op.poke()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "key, bucket, parsed_key, parsed_bucket",
+        [
+            pytest.param("s3://bucket/key", None, "key", "bucket", id="key as s3url"),
+            pytest.param("key", "bucket", "key", "bucket", id="separate bucket and key"),
+        ],
+    )
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_deferrable_poke_parse_bucket_key(
+        self, mock_head_object_async, key, bucket, parsed_key, parsed_bucket
+    ):
+        print(key, bucket, parsed_key, parsed_bucket)
+        mock_head_object_async.return_value = None
+
+        op = S3KeyTrigger(
+            bucket_key=key,
+            bucket_name=bucket,
+        )
+        await op.poke()
+
+        mock_head_object_async.assert_called_once_with(parsed_key, parsed_bucket)
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_deferrable_poke_multiple_files(self, mock_head_object_async):
+        s3_key_trigger = S3KeyTrigger(
+            bucket_key=["s3://test_bucket/file1", "s3://test_bucket/file2"],
+            wildcard_match=False,
+            aws_conn_id=TEST_CONN_ID,
+            verify=TEST_VERIFY,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+        mock_head_object_async.side_effect = [{"ContentLength": 0}, None]
+
+        response = await s3_key_trigger.poke()
+        assert response == [False, [[{"Size": 0}]]]
+
+        mock_head_object_async.side_effect = [{"ContentLength": 0}, {"ContentLength": 0}]
+
+        response = await s3_key_trigger.poke()
+        assert response == [True, [[{"Size": 0}], [{"Size": 0}]]]
+
+        mock_head_object_async.assert_any_call("file1", "test_bucket")
+        mock_head_object_async.assert_any_call("file2", "test_bucket")
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "get_file_metadata_async")
+    async def test_poke_deferrable_wildcard(self, mock_get_file_metadata):
+        op = S3KeyTrigger(bucket_key="s3://test_bucket/file*", wildcard_match=True)
+
+        mock_get_file_metadata.return_value = []
+        assert await op.poke() == [False, []]
+        mock_get_file_metadata.assert_called_once_with("file", "test_bucket")
+
+        mock_get_file_metadata.return_value = [{"Key": "dummyFile", "Size": 0}]
+        assert await op.poke() == [False, []]
+
+        mock_get_file_metadata.return_value = [{"Key": "file1", "Size": 12}]
+        assert await op.poke() == [True, [[{"Size": 12}]]]
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "get_file_metadata_async")
+    async def test_poke_deferrable_wildcard_multiple_files(self, mock_get_file_metadata_async):
+        op = S3KeyTrigger(
+            bucket_key=["s3://test_bucket/file*", "s3://test_bucket/*.zip"],
+            wildcard_match=True,
+        )
+
+        mock_get_file_metadata_async.side_effect = [[{"Key": "file1", "Size": 123}], []]
+        assert await op.poke() == [False, [[{"Size": 123}]]]
+
+        mock_get_file_metadata_async.side_effect = [
+            [{"Key": "file1", "Size": 123}],
+            [{"Key": "file2", "Size": 456}],
+        ]
+        assert await op.poke() == [False, [[{"Size": 123}]]]
+
+        mock_get_file_metadata_async.side_effect = [
+            [{"Key": "file1", "Size": 123}],
+            [{"Key": "test.zip", "Size": 456}],
+        ]
+        assert await op.poke() == [True, [[{"Size": 123}], [{"Size": 456}]]]
+
+        mock_get_file_metadata_async.assert_any_call("file", "test_bucket")
+        mock_get_file_metadata_async.assert_any_call("", "test_bucket")
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_poke_with_check_function(self, mock_head_object_async):
+        def check_fn(files: list) -> bool:
+            return all(f.get("Size", 0) > 0 for f in files)
+
+        mock_head_object_async.return_value = {"ContentLength": 1}
+        trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file")
+        generator = trigger.run()
+        trigger_response = await generator.asend(None)
+        assert trigger_response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "S3KeyTrigger success",
+                "files_list": [[{"Size": 1}]],
+            }
+        )
+        op = S3KeySensor(
+            task_id="test_poke_with_check_function",
+            bucket_key="s3://test_bucket/file",
+            check_fn=check_fn,
+        )
+        response = op.execute_complete(None, event=trigger_response.payload)
+
+        assert response is True
+
+        mock_head_object_async.return_value = {"ContentLength": 0}
+        trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file")
+        generator = trigger.run()
+        trigger_response = await generator.asend(None)
+
+        op = S3KeySensor(
+            task_id="test_poke_with_check_function",
+            bucket_key="s3://test_bucket/file",
+            check_fn=check_fn,
+        )
+        response = op.execute_complete(None, event=trigger_response.payload)
+
+        assert response is False
+
+    @pytest.mark.asyncio
+    @mock.patch.object(S3Hook, "head_object_async")
+    async def test_poke_with_check_function_with_multiple_files(self, mock_head_object_async):
+        def check_fn(files: list) -> bool:
+            return all(f.get("Size", 0) > 0 for f in files)
+
+        mock_head_object_async.side_effect = [{"ContentLength": 0}, {"ContentLength": 0}]
+        trigger = S3KeyTrigger(bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"])
+        generator = trigger.run()
+        trigger_response = await generator.asend(None)
+        assert trigger_response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "S3KeyTrigger success",
+                "files_list": [[{"Size": 0}], [{"Size": 0}]],
+            }
+        )
+        op = S3KeySensor(
+            task_id="test_poke_with_check_function",
+            bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"],
+            check_fn=check_fn,
+        )
+        response = op.execute_complete(None, event=trigger_response.payload)
+
+        assert response is False
+
+        mock_head_object_async.side_effect = [{"ContentLength": 0}, {"ContentLength": 1}]
+        trigger = S3KeyTrigger(bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"])
+        generator = trigger.run()
+        trigger_response = await generator.asend(None)
+
+        assert trigger_response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "S3KeyTrigger success",
+                "files_list": [[{"Size": 0}], [{"Size": 1}]],
+            }
+        )
+        op = S3KeySensor(
+            task_id="test_poke_with_check_function",
+            bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"],
+            check_fn=check_fn,
+        )
+        response = op.execute_complete(None, event=trigger_response.payload)
+
+        assert response is False
+
+        mock_head_object_async.side_effect = [{"ContentLength": 123}, {"ContentLength": 456}]
+        trigger = S3KeyTrigger(bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"])
+        generator = trigger.run()
+        trigger_response = await generator.asend(None)
+
+        assert trigger_response == TriggerEvent(
+            {
+                "status": "success",
+                "message": "S3KeyTrigger success",
+                "files_list": [[{"Size": 123}], [{"Size": 456}]],
+            }
+        )
+        op = S3KeySensor(
+            task_id="test_poke_with_check_function",
+            bucket_key=["s3://test_bucket/file", "s3://test_bucket_2/file"],
+            check_fn=check_fn,
+        )
+        response = op.execute_complete(None, event=trigger_response.payload)
+
+        assert response is True
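For illustration, a DAG opts into the new deferrable path the same way it uses the existing sensor, just with the added flag. The following is a minimal sketch; the DAG id, schedule, bucket, and key pattern are placeholder assumptions, not values taken from this patch:

```python
# Hypothetical usage of the deferrable S3KeySensor introduced above.
# DAG id, schedule, bucket, and key pattern are illustrative placeholders.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(
    dag_id="example_s3_key_deferrable",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    wait_for_file = S3KeySensor(
        task_id="wait_for_file",
        bucket_key="s3://my-bucket/incoming/*.csv",
        wildcard_match=True,
        # New flag from this change: hand the poke loop off to the triggerer.
        deferrable=True,
    )
```

With `deferrable=True`, `execute()` defers to `S3KeyTrigger` instead of looping in `poke()`; once the trigger yields a success event, `execute_complete()` applies any `check_fn` to the `files_list` carried in the event payload.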