18 changes: 16 additions & 2 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -19,8 +19,10 @@
from __future__ import annotations

import os
import warnings
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -40,7 +42,7 @@ class GCSToS3Operator(BaseOperator):
:param bucket: The Google Cloud Storage bucket to find the objects. (templated)
:param prefix: Prefix string which filters objects whose name begin with
this prefix. (templated)
:param delimiter: The delimiter by which you want to filter the objects. (templated)
:param delimiter: (Deprecated) The delimiter by which you want to filter the objects. (templated)
For example, to list the CSV files in a directory in GCS you would use
delimiter='.csv'.
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
@@ -76,6 +78,8 @@ class GCSToS3Operator(BaseOperator):
object to be uploaded in S3
:param keep_directory_structure: (Optional) When set to False the path of the file
on the bucket is recreated within path passed in dest_s3_key.
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g. ``'**/*/.json'``)
"""

template_fields: Sequence[str] = (
@@ -102,12 +106,19 @@ def __init__(
dest_s3_extra_args: dict | None = None,
s3_acl_policy: str | None = None,
keep_directory_structure: bool = True,
match_glob: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.bucket = bucket
self.prefix = prefix
if delimiter:
warnings.warn(
"Usage of 'delimiter' is deprecated, please use 'match_glob' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.delimiter = delimiter
self.gcp_conn_id = gcp_conn_id
self.dest_aws_conn_id = dest_aws_conn_id
@@ -118,6 +129,7 @@ def __init__(
self.dest_s3_extra_args = dest_s3_extra_args or {}
self.s3_acl_policy = s3_acl_policy
self.keep_directory_structure = keep_directory_structure
self.match_glob = match_glob

def execute(self, context: Context) -> list[str]:
# list all files in a Google Cloud Storage bucket
@@ -133,7 +145,9 @@ def execute(self, context: Context) -> list[str]:
self.prefix,
)

files = hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
files = hook.list(
bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob
)

s3_hook = S3Hook(
aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify, extra_args=self.dest_s3_extra_args
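A minimal usage sketch of the updated operator (not part of the diff): match_glob replaces the deprecated delimiter, which still works but now emits AirflowProviderDeprecationWarning. Bucket names, prefixes and connection IDs below are illustrative.

from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

# Hypothetical task definition; all names and IDs are placeholders.
copy_json_files = GCSToS3Operator(
    task_id="copy_json_files",
    bucket="example-gcs-bucket",
    prefix="exports/2023-05-01/",
    match_glob="**/*.json",  # replaces the deprecated delimiter=".json"
    dest_s3_key="s3://example-s3-bucket/exports/",
    gcp_conn_id="google_cloud_default",
    dest_aws_conn_id="aws_default",
)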
126 changes: 108 additions & 18 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -24,6 +24,7 @@
import os
import shutil
import time
import warnings
from contextlib import contextmanager
from datetime import datetime
from functools import partial
@@ -44,7 +45,7 @@
from google.cloud.storage.retry import DEFAULT_RETRY
from requests import Session

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
@@ -709,6 +710,7 @@ def list(
max_results: int | None = None,
prefix: str | List[str] | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
):
"""
List all objects from the bucket with the given single prefix or multiple prefixes.
@@ -717,9 +719,19 @@
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: string or list of strings which filter objects whose name begin with it/them
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (e.g. '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g. ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
if delimiter and delimiter != "/":
warnings.warn(
"Usage of 'delimiter' param is deprecated, please use 'match_glob' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if match_glob and delimiter and delimiter != "/":
raise AirflowException("'match_glob' param cannot be used with 'delimiter' that differs than '/'")
objects = []
if isinstance(prefix, list):
for prefix_item in prefix:
@@ -730,6 +742,7 @@
max_results=max_results,
prefix=prefix_item,
delimiter=delimiter,
match_glob=match_glob,
)
)
else:
@@ -740,6 +753,7 @@
max_results=max_results,
prefix=prefix,
delimiter=delimiter,
match_glob=match_glob,
)
)
return objects
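For reference, a minimal call sketch against the updated hook ("example-bucket" and the patterns are illustrative); combining match_glob with a delimiter other than '/' raises AirflowException.

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")
# New style: glob-based filtering.
json_objects = hook.list(bucket_name="example-bucket", prefix="raw/", match_glob="**/*.json")
# Old style: still accepted, but emits AirflowProviderDeprecationWarning.
csv_objects = hook.list(bucket_name="example-bucket", delimiter=".csv")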
@@ -751,6 +765,7 @@ def _list(
max_results: int | None = None,
prefix: str | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
) -> List:
"""
List all objects from the bucket with the given string prefix in name.
@@ -759,7 +774,9 @@
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: string which filters objects whose name begin with it
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (e.g. '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g. ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
@@ -768,13 +785,25 @@
ids = []
page_token = None
while True:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)
if match_glob:
blobs = self._list_blobs_with_match_glob(
bucket=bucket,
client=client,
match_glob=match_glob,
max_results=max_results,
page_token=page_token,
path=bucket.path + "/o",
prefix=prefix,
versions=versions,
)
else:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)

blob_names = []
for blob in blobs:
@@ -792,6 +821,52 @@
break
return ids

@staticmethod
def _list_blobs_with_match_glob(
bucket,
client,
path: str,
max_results: int | None = None,
page_token: str | None = None,
match_glob: str | None = None,
prefix: str | None = None,
versions: bool | None = None,
) -> Any:
"""
List blobs when the match_glob param is given.
This method is a patched version of google.cloud.storage Client.list_blobs().
It is used as a temporary workaround to support the "match_glob" param,
as it isn't officially supported by the GCS Python client
(follow `issue #1035 <https://github.com/googleapis/python-storage/issues/1035>`__).
"""
from google.api_core import page_iterator
from google.cloud.storage.bucket import _blobs_page_start, _item_to_blob

extra_params: Any = {}
if prefix is not None:
extra_params["prefix"] = prefix
if match_glob is not None:
extra_params["matchGlob"] = match_glob
if versions is not None:
extra_params["versions"] = versions
api_request = functools.partial(
client._connection.api_request, timeout=DEFAULT_TIMEOUT, retry=DEFAULT_RETRY
)

blobs: Any = page_iterator.HTTPIterator(
client=client,
api_request=api_request,
path=path,
item_to_value=_item_to_blob,
page_token=page_token,
max_results=max_results,
extra_params=extra_params,
page_start=_blobs_page_start,
)
blobs.prefixes = set()
blobs.bucket = bucket
return blobs
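As a side note and an assumption not stated in this PR: later releases of the google-cloud-storage client are expected to expose match_glob on list_blobs directly (the linked issue #1035 tracks this), which would make the patched iterator above unnecessary. A sketch assuming such a client version, with illustrative names:

# Assumes a google-cloud-storage release that accepts match_glob natively.
from google.cloud import storage

client = storage.Client()
blobs = client.list_blobs("example-bucket", prefix="raw/", match_glob="**/*.json")
names = [blob.name for blob in blobs]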

def list_by_timespan(
self,
bucket_name: str,
@@ -801,6 +876,7 @@
max_results: int | None = None,
prefix: str | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
) -> List[str]:
"""
List all objects from the bucket with the given string prefix in name that were
@@ -813,7 +889,9 @@
:param max_results: max count of items to return in a single page of responses
:param prefix: prefix string which filters objects whose name begin with
this prefix
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param delimiter: (Deprecated) filters objects based on the delimiter (e.g. '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g. ``'**/*/.json'``).
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
@@ -823,13 +901,25 @@
page_token = None

while True:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)
if match_glob:
blobs = self._list_blobs_with_match_glob(
bucket=bucket,
client=client,
match_glob=match_glob,
max_results=max_results,
page_token=page_token,
path=bucket.path + "/o",
prefix=prefix,
versions=versions,
)
else:
blobs = bucket.list_blobs(
max_results=max_results,
page_token=page_token,
prefix=prefix,
delimiter=delimiter,
versions=versions,
)

blob_names = []
for blob in blobs:
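Finally, a brief sketch for the timespan variant, assuming the timespan bounds are the hook's timespan_start / timespan_end parameters (collapsed in this excerpt) and using illustrative names:

from datetime import datetime, timedelta, timezone

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=1)
# Objects updated in the last hour whose names match the glob; bucket and pattern are placeholders.
recent_json = hook.list_by_timespan(
    bucket_name="example-bucket",
    timespan_start=start,
    timespan_end=end,
    match_glob="**/*.json",
)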