5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -161,6 +161,11 @@ def execute(self, context: Context) -> list[str]:
# and only keep those files which are present in
# Google Cloud Storage and not in S3
bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
# if prefix is empty, do not add "/" at the end, since a "/"
# prefix would filter out all the objects (returning an empty
# list), whereas an empty prefix returns all the objects
if prefix:
prefix = prefix if prefix.endswith("/") else f"{prefix}/"
# look for the bucket and the prefix to avoid looking into
# parent directories/keys
existing_files = s3_hook.list_keys(bucket_name, prefix=prefix)
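For reference, a minimal standalone sketch (helper name hypothetical, not part of the operator) of the prefix normalization this hunk adds, showing why an empty prefix must not get a trailing slash:

def _normalize_prefix(prefix: str) -> str:
    # An empty prefix must stay empty: listing with prefix="/" would match
    # no keys, while prefix="" lists every object in the bucket.
    if prefix:
        prefix = prefix if prefix.endswith("/") else f"{prefix}/"
    return prefix


assert _normalize_prefix("") == ""
assert _normalize_prefix("test") == "test/"
assert _normalize_prefix("test/") == "test/"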
69 changes: 52 additions & 17 deletions tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -33,6 +33,7 @@
S3_BUCKET = "s3://bucket/"
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
S3_ACL_POLICY = "private-read"
deprecated_call_match = "Usage of 'delimiter' is deprecated, please use 'match_glob' instead"


def _create_test_bucket():
@@ -47,8 +48,6 @@ def _create_test_bucket():

@mock_s3
class TestGCSToS3Operator:

# Test0: match_glob
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute__match_glob(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
@@ -73,15 +72,14 @@ def test_execute__match_glob(self, mock_hook):
bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX
)

# Test1: incremental behaviour (just some files missing)
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_incremental(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -100,15 +98,17 @@ def test_execute_incremental(self, mock_hook):
assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test2: All the files are already in origin and destination without replace
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_without_replace(self, mock_hook):
"""
Tests the scenario where all the files are already in origin and destination without replace
"""
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -128,15 +128,53 @@ def test_execute_without_replace(self, mock_hook):
assert [] == uploaded_files
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test3: There are no files in destination bucket
@pytest.mark.parametrize(
argnames="dest_s3_url",
argvalues=[f"{S3_BUCKET}/test/", f"{S3_BUCKET}/test"],
)
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_without_replace_with_folder_structure(self, mock_hook, dest_s3_url):
mock_files_gcs = [f"test{idx}/{mock_file}" for idx, mock_file in enumerate(MOCK_FILES)]
mock_files_s3 = [f"test/test{idx}/{mock_file}" for idx, mock_file in enumerate(MOCK_FILES)]
mock_hook.return_value.list.return_value = mock_files_gcs

hook, bucket = _create_test_bucket()
for mock_file_s3 in mock_files_s3:
bucket.put_object(Key=mock_file_s3, Body=b"testing")

with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
dest_s3_key=dest_s3_url,
replace=False,
)

# we expect nothing to be uploaded and all the pre-existing
# objects (mock_files_s3) to still be present in the S3 bucket
uploaded_files = operator.execute(None)

assert [] == uploaded_files
assert sorted(mock_files_s3) == sorted(hook.list_keys("bucket", prefix="test/"))

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute(self, mock_hook):
"""
Tests the scenario where there are no files in the destination bucket
"""
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -154,15 +192,14 @@ def test_execute(self, mock_hook):
assert sorted(MOCK_FILES) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test4: Destination and Origin are in sync but replace all files in destination
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_with_replace(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -182,15 +219,14 @@ def test_execute_with_replace(self, mock_hook):
assert sorted(MOCK_FILES) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

# Test5: Incremental sync with replace
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_incremental_with_replace(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
with NamedTemporaryFile() as f:
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -218,7 +254,7 @@ def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hoo
s3_mock_hook.return_value = mock.Mock()
s3_mock_hook.parse_s3_url.return_value = mock.Mock()

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -241,7 +277,7 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
s3_mock_hook.return_value = mock.Mock()
s3_mock_hook.parse_s3_url.return_value = mock.Mock()

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -259,7 +295,6 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
aws_conn_id="aws_default", extra_args={"ContentLanguage": "value"}, verify=None
)

# Test6: s3_acl_policy parameter is set
@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_file")
def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
@@ -268,7 +303,7 @@ def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
gcs_provide_file = mock_gcs_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
@@ -293,7 +328,7 @@ def test_execute_without_keep_director_structure(self, mock_hook):
gcs_provide_file = mock_hook.return_value.provide_file
gcs_provide_file.return_value.__enter__.return_value.name = f.name

with pytest.deprecated_call():
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
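For context, a minimal sketch (function and test names hypothetical) of how pytest.deprecated_call(match=...), as used throughout these tests, only accepts a deprecation warning whose message matches the given pattern:

import warnings

import pytest


def old_api() -> None:
    # Hypothetical stand-in for the deprecated 'delimiter' code path.
    warnings.warn(
        "Usage of 'delimiter' is deprecated, please use 'match_glob' instead",
        DeprecationWarning,
    )


def test_match_narrows_the_assertion() -> None:
    # Without match=, any DeprecationWarning would satisfy the assertion;
    # with match=, only a warning whose message matches the pattern does.
    with pytest.deprecated_call(match="please use 'match_glob' instead"):
        old_api()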