Merged

Changes from all commits (37 commits)
3ded5c9
Make S3KeySensor deferrable
sunank200 May 2, 2023
3d1c188
Remove aiobotocore from dependency
sunank200 May 2, 2023
85e6edb
Add the S3 trigger path
sunank200 May 8, 2023
95e74c6
Add the tests
sunank200 May 8, 2023
f613dd0
Merge branch 'main' into s3keysensor
sunank200 May 8, 2023
602d9d6
remove unnecessary tests
sunank200 May 9, 2023
ae83da3
Fix the docs
sunank200 May 9, 2023
6e1f0a0
update docs and example DAG
sunank200 May 10, 2023
11ce1b8
Merge branch 'apache:main' into s3keysensor
sunank200 May 10, 2023
41b0c22
Merge branch 'apache:main' into s3keysensor
sunank200 May 29, 2023
e5f515f
Merge branch 'main' into s3keysensor
sunank200 May 29, 2023
571ffa6
Remove S3HookAsync and use S3Hook with async_conn
sunank200 May 29, 2023
0785452
Add tests for hooks and triggers
sunank200 May 29, 2023
38583a7
remove aiohttp
sunank200 May 29, 2023
c4be0da
remove test for Python 3.7 in Airflow
sunank200 May 29, 2023
e1bbaec
Add more test to system tests
sunank200 May 30, 2023
debf4a7
Merge branch 'main' into s3keysensor
sunank200 May 30, 2023
3701f61
Merge branch 'apache:main' into s3keysensor
sunank200 May 30, 2023
b75c3b9
Merge branch 'main' into s3keysensor
sunank200 May 30, 2023
d7b3188
Merge branch 'apache:main' into s3keysensor
sunank200 Jun 5, 2023
d36487d
add check_fn fix
sunank200 Jun 7, 2023
c3e8d8e
Merge branch 'main' into s3keysensor
sunank200 Jun 7, 2023
6be1156
Add tests for chech_fn
sunank200 Jun 7, 2023
fa49d9c
Remove s3 key unchanged code
sunank200 Jun 7, 2023
08ee769
Add the should_check_fn in serializer
sunank200 Jun 7, 2023
df55976
Update triggers integration-name in providers.yaml
sunank200 Jun 7, 2023
3ac1de4
Refactor integration name from Amazon S3 to Amazon Simple Storage Ser…
sunank200 Jun 7, 2023
545a26b
add type checking
sunank200 Jun 7, 2023
8d18db9
Add . for static checksin doc-strings
sunank200 Jun 7, 2023
9f62328
Merge remote-tracking branch 'upstream/main' into s3keysensor
sunank200 Jun 7, 2023
f598da1
Merge branch 'main' into s3keysensor
sunank200 Jun 7, 2023
01f64fe
Merge branch 's3keysensor' of https://github.com/sunank200/airflow in…
sunank200 Jun 7, 2023
66732fd
Merge branch 'main' into s3keysensor
sunank200 Jun 7, 2023
110a8f9
change doc string
sunank200 Jun 7, 2023
83c3bed
Merge branch 's3keysensor' of https://github.com/sunank200/airflow in…
sunank200 Jun 7, 2023
66e7e4e
Merge branch 'main' into s3keysensor
sunank200 Jun 7, 2023
1b315a1
Merge branch 'main' into s3keysensor
sunank200 Jun 8, 2023
252 changes: 251 additions & 1 deletion airflow/providers/amazon/aws/hooks/s3.py
@@ -18,6 +18,7 @@
"""Interact with AWS S3, using the boto3 library."""
from __future__ import annotations

import asyncio
import fnmatch
import gzip as gz
import io
@@ -38,6 +39,13 @@
from urllib.parse import urlsplit
from uuid import uuid4

if TYPE_CHECKING:
try:
from aiobotocore.client import AioBaseClient
except ImportError:
pass

from asgiref.sync import sync_to_async
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

@@ -88,6 +96,29 @@ def wrapper(*args, **kwargs) -> T:
return cast(T, wrapper)


def provide_bucket_name_async(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
in case no bucket name has been passed to the function.
"""
function_signature = signature(func)

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
self = args[0]
if self.aws_conn_id:
connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
if connection.schema:
bound_args.arguments["bucket_name"] = connection.schema

return await func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
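
The decorator binds the wrapped call's arguments and, when bucket_name is absent, falls back to the Airflow connection's schema field. A minimal self-contained sketch of the same pattern, using stand-in FakeConnection/FakeHook classes (hypothetical names, not Airflow's real objects) and a plain async lookup in place of sync_to_async:

import asyncio
from functools import wraps
from inspect import signature


def provide_bucket_name_async_demo(func):
    # Same shape as the decorator above, minus the Airflow-specific connection lookup.
    function_signature = signature(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        bound_args = function_signature.bind(*args, **kwargs)
        if "bucket_name" not in bound_args.arguments:
            self = args[0]
            # The real decorator wraps a synchronous get_connection with sync_to_async;
            # this stand-in hook exposes an async lookup directly.
            connection = await self.get_connection(self.aws_conn_id)
            if connection.schema:
                bound_args.arguments["bucket_name"] = connection.schema
        return await func(*bound_args.args, **bound_args.kwargs)

    return wrapper


class FakeConnection:
    schema = "bucket-from-connection"


class FakeHook:
    aws_conn_id = "aws_default"

    async def get_connection(self, conn_id):
        return FakeConnection()

    @provide_bucket_name_async_demo
    async def which_bucket(self, key, bucket_name=None):
        return bucket_name


print(asyncio.run(FakeHook().which_bucket("data/file.csv")))  # bucket-from-connection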


def unify_bucket_name_and_key(func: T) -> T:
"""
Function decorator that unifies bucket name and key taken from the key
@@ -228,7 +259,6 @@ def get_s3_bucket_key(
f"If `{bucket_param_name}` is provided, {key_param_name} should be a relative path "
"from root level, rather than a full s3:// url"
)

return bucket, key

@provide_bucket_name
@@ -363,6 +393,226 @@ def list_prefixes(

return prefixes

@provide_bucket_name_async
@unify_bucket_name_and_key
async def get_head_object_async(
self, client: AioBaseClient, key: str, bucket_name: str | None = None
) -> dict[str, Any] | None:
"""
Retrieves metadata of an object.

:param client: aiobotocore client
:param bucket_name: Name of the bucket in which the file is stored
:param key: S3 key that will point to the file
"""
head_object_val: dict[str, Any] | None = None
try:
head_object_val = await client.head_object(Bucket=bucket_name, Key=key)
return head_object_val
except ClientError as e:
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
return head_object_val
else:
raise e

async def list_prefixes_async(
self,
client: AioBaseClient,
bucket_name: str | None = None,
prefix: str | None = None,
delimiter: str | None = None,
page_size: int | None = None,
max_items: int | None = None,
) -> list[Any]:
"""
Lists prefixes in a bucket under prefix.

:param client: aiobotocore client
:param bucket_name: the name of the bucket
:param prefix: a key prefix
:param delimiter: the delimiter marks key hierarchy.
:param page_size: pagination size
:param max_items: maximum items to return
:return: a list of matched prefixes
"""
prefix = prefix or ""
delimiter = delimiter or ""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
)

prefixes = []
async for page in response:
if "CommonPrefixes" in page:
for common_prefix in page["CommonPrefixes"]:
prefixes.append(common_prefix["Prefix"])

return prefixes

@provide_bucket_name_async
async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]:
"""
Asynchronously gets a list of files in the bucket that match the prefix derived from a wildcard key expression.

:param client: aiobotocore client
:param bucket_name: the name of the bucket
:param key: the path to the key
"""
prefix = re.split(r"[\[\*\?]", key, 1)[0]
delimiter = ""
paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
files = []
async for page in response:
if "Contents" in page:
files += page["Contents"]
return files
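
The prefix handed to the paginator is everything before the first wildcard character in the key. A small illustration of that split:

import re

key = "logs/2023-*/part-[0-9].json"
prefix = re.split(r"[\[\*\?]", key, 1)[0]
print(prefix)  # logs/2023-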

async def _check_key_async(
self,
client: AioBaseClient,
bucket_val: str,
wildcard_match: bool,
key: str,
) -> bool:
"""
Check whether a key exists in the bucket asynchronously and return the boolean result.
If wildcard_match is True, list the files whose keys match the wildcard expression;
otherwise fetch the head object for the key and check that it exists.

:param client: aiobotocore client
:param bucket_val: the name of the bucket
:param key: S3 key that will point to the file
:param wildcard_match: whether the key should be interpreted as a Unix wildcard pattern
"""
bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
if wildcard_match:
keys = await self.get_file_metadata_async(client, bucket_name, key)
key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
if len(key_matches) == 0:
return False
else:
obj = await self.get_head_object_async(client, key, bucket_name)
if obj is None:
return False

return True
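
The wildcard branch filters the listed objects with fnmatch against the full key expression, for example:

import fnmatch

listed = [{"Key": "data/2023-06-01.csv"}, {"Key": "data/notes.txt"}]
matches = [o["Key"] for o in listed if fnmatch.fnmatch(o["Key"], "data/*.csv")]
print(matches)  # ['data/2023-06-01.csv']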

async def check_key_async(
self,
client: AioBaseClient,
bucket: str,
bucket_keys: str | list[str],
wildcard_match: bool,
) -> bool:
"""
Checks that all the given keys exist in the bucket and returns a boolean value.

:param client: aiobotocore client
:param bucket: the name of the bucket
:param bucket_keys: S3 keys that will point to the file
:param wildcard_match: whether the keys should be interpreted as Unix wildcard patterns
"""
if isinstance(bucket_keys, list):
return all(
await asyncio.gather(
*(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys)
)
)
return await self._check_key_async(client, bucket, wildcard_match, bucket_keys)
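
A usage sketch, assuming aiobotocore is installed, an aws_default connection is resolvable, and the bucket and keys are placeholders (inside the provider the client comes from the hook's async connection rather than a hand-built session):

import asyncio

from aiobotocore.session import get_session

from airflow.providers.amazon.aws.hooks.s3 import S3Hook


async def main():
    hook = S3Hook(aws_conn_id="aws_default")
    async with get_session().create_client("s3") as client:
        # Gathers one _check_key_async coroutine per key and ANDs the results.
        found = await hook.check_key_async(
            client, "my-bucket", ["data/a.csv", "data/b.csv"], wildcard_match=False
        )
        print(found)


asyncio.run(main())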

async def check_for_prefix_async(
self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
) -> bool:
"""
Checks that a prefix exists in a bucket.

:param client: aiobotocore client
:param bucket_name: the name of the bucket
:param prefix: a key prefix
:param delimiter: the delimiter marks key hierarchy.
:return: False if the prefix does not exist in the bucket and True if it does.
"""
prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1)
previous_level = prefix_split[0]
plist = await self.list_prefixes_async(client, bucket_name, previous_level, delimiter)
return prefix in plist
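
The parent level handed to list_prefixes_async is derived by trimming the last path component from the (delimiter-terminated) prefix, for example:

import re

prefix, delimiter = "raw/events", "/"
prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
previous_level = re.split(rf"(\w+[{delimiter}])$", prefix, 1)[0]
print(prefix, previous_level)  # raw/events/ raw/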

async def _check_for_prefix_async(
self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
) -> bool:
return await self.check_for_prefix_async(
client, prefix=prefix, delimiter=delimiter, bucket_name=bucket_name
)

async def get_files_async(
self,
client: AioBaseClient,
bucket: str,
bucket_keys: str | list[str],
wildcard_match: bool,
delimiter: str | None = "/",
) -> list[Any]:
"""Gets a list of files in the bucket."""
keys: list[Any] = []
for key in bucket_keys:
prefix = key
if wildcard_match:
prefix = re.split(r"[\[\*\?]", key, 1)[0]

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
async for page in response:
if "Contents" in page:
_temp = [k for k in page["Contents"] if isinstance(k.get("Size", None), (int, float))]
keys = keys + _temp
return keys

@staticmethod
async def _list_keys_async(
client: AioBaseClient,
bucket_name: str | None = None,
prefix: str | None = None,
delimiter: str | None = None,
page_size: int | None = None,
max_items: int | None = None,
) -> list[str]:
"""
Lists keys in a bucket under prefix and not containing delimiter.

:param bucket_name: the name of the bucket
:param prefix: a key prefix
:param delimiter: the delimiter marks key hierarchy.
:param page_size: pagination size
:param max_items: maximum items to return
:return: a list of matched keys
"""
prefix = prefix or ""
delimiter = delimiter or ""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}

paginator = client.get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
)

keys = []
async for page in response:
if "Contents" in page:
for k in page["Contents"]:
keys.append(k["Key"])

return keys

def _list_key_object_filter(
self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None
) -> list:
55 changes: 50 additions & 5 deletions airflow/providers/amazon/aws/sensors/s3.py
@@ -20,9 +20,9 @@
import fnmatch
import os
import re
from datetime import datetime
from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Callable, Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence, cast

from deprecated import deprecated

@@ -31,6 +31,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger
from airflow.sensors.base import BaseSensorOperator, poke_mode_only


@@ -48,7 +49,7 @@ class S3KeySensor(BaseSensorOperator):
or relative path from root level. When it's specified as a full s3://
url, please leave bucket_name as `None`
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url. When specified, all the keys passed to ``bucket_key``
is not provided as a full ``s3://`` url. When specified, all the keys passed to ``bucket_key``
refer to this bucket
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
@@ -61,8 +62,9 @@
def check_fn(files: List) -> bool:
return any(f.get('Size', 0) > 1048576 for f in files)
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
:param deferrable: Run operator in the deferrable mode
:param verify: Whether to verify SSL certificates for S3 connection.
By default, SSL certificates are verified.
You can provide the following values:

- ``False``: do not validate SSL certificates. SSL will still be used
@@ -84,6 +86,7 @@ def __init__(
check_fn: Callable[..., bool] | None = None,
aws_conn_id: str = "aws_default",
verify: str | bool | None = None,
deferrable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -93,6 +96,7 @@ def __init__(
self.check_fn = check_fn
self.aws_conn_id = aws_conn_id
self.verify = verify
self.deferrable = deferrable

def _check_key(self, key):
bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
@@ -131,6 +135,47 @@ def poke(self, context: Context):
else:
return all(self._check_key(key) for key in self.bucket_key)

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
if not self.deferrable:
super().execute(context)
else:
if not self.poke(context=context):
self._defer()

def _defer(self) -> None:
"""Check for a keys in s3 and defers using the triggerer."""
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=S3KeyTrigger(
bucket_name=cast(str, self.bucket_name),
bucket_key=self.bucket_key,
wildcard_match=self.wildcard_match,
aws_conn_id=self.aws_conn_id,
verify=self.verify,
poke_interval=self.poke_interval,
should_check_fn=True if self.check_fn else False,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> bool | None:
"""
Callback for when the trigger fires. If a check_fn was supplied, it is applied to the
files reported by the trigger and the sensor defers again when it is not satisfied.
An error event from the trigger raises an AirflowException.
"""
if event["status"] == "running":
found_keys = self.check_fn(event["files"]) # type: ignore[misc]
if found_keys:
return None
else:
self._defer()

if event["status"] == "error":
raise AirflowException(event["message"])
return None
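
A minimal DAG sketch using the sensor in deferrable mode (dag_id, bucket, and key are placeholders):

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(
    dag_id="example_s3_key_sensor_deferrable",
    start_date=datetime(2023, 6, 1),
    schedule=None,
    catchup=False,
):
    wait_for_key = S3KeySensor(
        task_id="wait_for_key",
        bucket_name="my-bucket",
        bucket_key="incoming/*.csv",
        wildcard_match=True,
        deferrable=True,  # free the worker slot; a triggerer process polls S3 asynchronously
        poke_interval=30,
        timeout=60 * 60,
    )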

@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> S3Hook:
"""Create and return an S3Hook."""