S3 Key sensor deferrable #31749
airflow/providers/amazon/aws/sensors/s3.py

```diff
@@ -22,7 +22,7 @@
 import re
 from datetime import datetime
 from functools import cached_property
-from typing import TYPE_CHECKING, Callable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence

 from deprecated import deprecated
```
```diff
@@ -71,6 +71,7 @@ def check_fn(files: List) -> bool:
         - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
             You can specify this argument if you want to use a different
             CA cert bundle than the one used by botocore.
+    :param deferrable: If True, the sensor will run in deferrable mode.
     """

     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
```
```diff
@@ -84,6 +85,7 @@ def __init__(
         check_fn: Callable[..., bool] | None = None,
         aws_conn_id: str = "aws_default",
         verify: str | bool | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
```
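With the flag in place, opting in is a one-line change on the task. A minimal sketch of a DAG using it; the DAG id, task id, bucket, and key below are illustrative, not taken from this PR:

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

# Hypothetical DAG: all names and values here are illustrative.
with DAG(dag_id="s3_key_deferrable_demo", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    wait_for_file = S3KeySensor(
        task_id="wait_for_file",
        bucket_name="my-bucket",      # hypothetical bucket
        bucket_key="incoming/*.csv",  # glob-style pattern, needs wildcard_match
        wildcard_match=True,
        deferrable=True,              # hands the polling off to the triggerer
    )
```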
```diff
@@ -93,54 +95,102 @@ def __init__(
         self.check_fn = check_fn
         self.aws_conn_id = aws_conn_id
         self.verify = verify
-
-    def _check_key(self, key):
-        bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
-        self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
-
-        """
-        Set variable `files` which contains a list of dict which contains only the size
-        If needed we might want to add other attributes later
-        Format: [{
-            'Size': int
-        }]
-        """
-        if self.wildcard_match:
-            prefix = re.split(r"[\[\*\?]", key, 1)[0]
-            keys = self.hook.get_file_metadata(prefix, bucket_name)
-            key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
-            if len(key_matches) == 0:
-                return False
-
-            # Reduce the set of metadata to size only
-            files = list(map(lambda f: {"Size": f["Size"]}, key_matches))
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger
+
+            self.defer(
+                trigger=S3KeyTrigger(
+                    bucket_name=self.bucket_name,
+                    bucket_key=self.bucket_key,
+                    wildcard_match=self.wildcard_match,
+                    aws_conn_id=self.aws_conn_id,
+                    verify=self.verify,
+                ),
+                method_name="execute_complete",
+            )
         else:
-            obj = self.hook.head_object(key, bucket_name)
-            if obj is None:
-                return False
-            files = [{"Size": obj["ContentLength"]}]
-
-        if self.check_fn is not None:
-            return self.check_fn(files)
-
-        return True
+            super().execute(context=context)

     def poke(self, context: Context):
         if isinstance(self.bucket_key, str):
-            return self._check_key(self.bucket_key)
+            self.bucket_keys = [self.bucket_key]
         else:
-            return all(self._check_key(key) for key in self.bucket_key)
+            self.bucket_keys = self.bucket_key
+        wildcard_keys = []
+        objs = []
+        bucket_key_names = []
+        for i in range(len(self.bucket_keys)):
+            bucket_key_names.append(
+                S3Hook.get_s3_bucket_key(self.bucket_name, self.bucket_keys[i], "bucket_name", "bucket_key")
+            )
+            bucket_name = bucket_key_names[i][0]
+            key = bucket_key_names[i][1]
+            self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
+            if self.wildcard_match:
+                prefix = re.split(r"[\[\*\?]", key, 1)[0]
+                wildcard_keys.append(self.hook.get_file_metadata(prefix, bucket_name))
+            else:
+                objs.append(self.hook.head_object(key, bucket_name))
+
+        results = process_files(
+            self.bucket_keys, self.wildcard_match, wildcard_keys, objs, self.check_fn, bucket_key_names
+        )[0]
+        return all(results)

     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> S3Hook:
         """Create and return an S3Hook."""
         return self.hook

+    def execute_complete(self, context, event=None):
+        self.log.info("Inside execute complete")
+        if event["status"] != "success":
+            raise AirflowException(f"Error: {event}")
+        else:
+            results = []
+            self.log.info("Success: %s", event)
+            if self.check_fn is not None:
+                for files in event["files_list"]:
+                    results.append(self.check_fn(files))
+            return all(results)
```
> **Member** commented on lines +155 to +158: Why build the entire `results` list? This could short-circuit instead:
>
> ```python
> for f in event["files_list"]:
>     if not self.check_fn(f):
>         return False
> return True
> ```
>
> or `return all(self.check_fn(f) for f in event["files_list"])`.
```diff
     @cached_property
     def hook(self) -> S3Hook:
         return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)


+def process_files(bucket_keys, wildcard_match, wildcard_keys, obj, check_fn, bucket_key_names):
+    results = []
+    files_list = []
+    for i in range(len(bucket_keys)):
+        key = bucket_key_names[i][1]
+        if wildcard_match:
+            key_matches = [k for k in wildcard_keys[i] if fnmatch.fnmatch(k["Key"], key)]
+            if len(key_matches) == 0:
+                results.append(False)
+                continue
+            # Reduce the set of metadata to size only
+            files_list.append(list(map(lambda f: {"Size": f["Size"]}, key_matches)))
```
> **Member**: This list-building code can be improved by using iterators.
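One possible reading of that suggestion, sketched below. `process_files_iter` is a hypothetical name, not part of the PR; the signature is kept for drop-in parity (`bucket_keys` ends up unused once `enumerate` replaces index arithmetic). As a side effect this also evaluates `check_fn` once per key, rather than re-running it over the whole accumulated `files_list` on every iteration as the code above does:

```python
import fnmatch


def process_files_iter(bucket_keys, wildcard_match, wildcard_keys, objs, check_fn, bucket_key_names):
    # Hypothetical rewrite, not the PR's code: one pass per key, lazy matching.
    results = []
    files_list = []
    for i, (_, key) in enumerate(bucket_key_names):
        if wildcard_match:
            # Generator: metadata rows are filtered lazily, never copied whole.
            key_matches = (k for k in wildcard_keys[i] if fnmatch.fnmatch(k["Key"], key))
            files = [{"Size": k["Size"]} for k in key_matches]
            if not files:
                results.append(False)
                continue
        else:
            if objs[i] is None:
                results.append(False)
                continue
            files = [{"Size": objs[i]["ContentLength"]}]
        files_list.append(files)
        # check_fn is consulted exactly once for this key's files.
        results.append(check_fn(files) if check_fn is not None else True)
    return [results, files_list]
```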
```diff
+        else:
+            if obj[i] is None:
+                results.append(False)
+                continue
+
+            files_list.append([{"Size": obj[i]["ContentLength"]}])
+
+        if check_fn is not None:
+            for files in files_list:
+                results.append(check_fn(files))
+            continue
+
+        results.append(True)
+
+    return [results, files_list]
+
+
 @poke_mode_only
 class S3KeysUnchangedSensor(BaseSensorOperator):
     """
```
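As the removed docstring noted, each `files` payload handed to `check_fn` is reduced to a list of `{"Size": int}` dicts. A minimal hypothetical predicate written against that contract:

```python
def check_non_empty(files: list) -> bool:
    # Succeed only once every matched object has a non-zero size.
    return all(f["Size"] > 0 for f in files)

# Passed to the sensor as: S3KeySensor(..., check_fn=check_non_empty)
```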
airflow/providers/amazon/aws/triggers/s3.py (new file)

```diff
@@ -0,0 +1,106 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import re
+from functools import cached_property
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.sensors.s3 import process_files
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class S3KeyTrigger(BaseTrigger):
+    """Trigger for S3KeySensor"""
+
+    def __init__(
+        self,
+        *,
+        bucket_key: str | list[str],
+        bucket_name: str | None = None,
+        wildcard_match: bool = False,
+        aws_conn_id: str = "aws_default",
+        verify: str | bool | None = None,
+        poll_interval: int = 60,
+    ):
+        self.bucket_name = bucket_name
+        self.bucket_key = bucket_key
+        self.wildcard_match = wildcard_match
+        self.aws_conn_id = aws_conn_id
+        self.verify = verify
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger",
+            {
+                "bucket_key": self.bucket_key,
+                "bucket_name": self.bucket_name,
+                "wildcard_match": self.wildcard_match,
+                "aws_conn_id": self.aws_conn_id,
+                "verify": self.verify,
+                "poll_interval": self.poll_interval,
+            },
+        )
```
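Since `serialize` just returns the trigger's classpath plus its kwargs, a quick round-trip sanity check (a hypothetical snippet, not part of the PR; bucket and key values are illustrative) looks like:

```python
from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger

trigger = S3KeyTrigger(bucket_key="incoming/data.csv", bucket_name="my-bucket")
classpath, kwargs = trigger.serialize()

assert classpath == "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger"
# Rebuilding from the serialized kwargs must yield an equivalent trigger.
assert S3KeyTrigger(**kwargs).serialize() == (classpath, kwargs)
```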
```diff
+    @cached_property
+    def hook(self) -> S3Hook:
+        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+
+    async def poke(self):
+        if isinstance(self.bucket_key, str):
+            self.bucket_keys = [self.bucket_key]
+        else:
+            self.bucket_keys = self.bucket_key
+
+        wildcard_keys = []
+        obj = []
+        bucket_key_names = []
+        for i in range(len(self.bucket_keys)):
```
> **Member**: This loop looks very much like the one in `S3KeySensor.poke` above.
```diff
+            bucket_key_names.append(
+                S3Hook.get_s3_bucket_key(self.bucket_name, self.bucket_keys[i], "bucket_name", "bucket_key")
+            )
+            bucket_name = bucket_key_names[i][0]
+            key = bucket_key_names[i][1]
+            self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
+            if self.wildcard_match:
+                prefix = re.split(r"[\[\*\?]", key, 1)[0]
+                wildcard_keys.append(await self.hook.get_file_metadata_async(prefix, bucket_name))
+            else:
+                obj.append(await self.hook.head_object_async(key, bucket_name))
+
+        response = process_files(
+            self.bucket_keys, self.wildcard_match, wildcard_keys, obj, None, bucket_key_names
+        )
+        return [all(response[0]), response[1]]
+
+    async def run(self):
+        while True:
+            response = await self.poke()
+            if response[0]:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": "S3KeyTrigger success",
+                        "files_list": response[1],
+                    }
+                )
+            else:
+                await asyncio.sleep(int(self.poll_interval))
```
> **Member**: This can use a rewrite with `enumerate`.
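A hedged sketch of that rewrite for the trigger's `poke` loop. Here `self` stands in for the `S3KeyTrigger` instance, and the body reuses the PR's own calls; in fact, once the per-key tuple is unpacked, plain iteration suffices and the index disappears entirely:

```python
import re

from airflow.providers.amazon.aws.hooks.s3 import S3Hook


async def poke(self):
    # Sketch only: logic mirrors the PR, minus the range(len(...)) indexing.
    bucket_keys = [self.bucket_key] if isinstance(self.bucket_key, str) else self.bucket_key
    wildcard_keys, objs, bucket_key_names = [], [], []
    for bucket_key in bucket_keys:  # direct iteration, no index bookkeeping
        bucket_name, key = S3Hook.get_s3_bucket_key(
            self.bucket_name, bucket_key, "bucket_name", "bucket_key"
        )
        bucket_key_names.append((bucket_name, key))
        self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
        if self.wildcard_match:
            prefix = re.split(r"[\[\*\?]", key, 1)[0]
            wildcard_keys.append(await self.hook.get_file_metadata_async(prefix, bucket_name))
        else:
            objs.append(await self.hook.head_object_async(key, bucket_name))
    # ...hand off to process_files exactly as the PR does from here.
```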