diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
index 888a5650d4962..61d9225e89562 100644
--- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -15,7 +15,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module contains a Google Storage Transfer Service Hook."""
+"""
+This module contains a Google Storage Transfer Service Hook.
+
+.. spelling::
+
+    ListTransferJobsAsyncPager
+    StorageTransferServiceAsyncClient
+
+"""
+
 from __future__ import annotations
 
 import json
@@ -24,13 +33,23 @@
 import warnings
 from copy import deepcopy
 from datetime import timedelta
-from typing import Sequence
-
+from typing import Any, Sequence
+
+from google.cloud.storage_transfer_v1 import (
+    ListTransferJobsRequest,
+    StorageTransferServiceAsyncClient,
+    TransferJob,
+    TransferOperation,
+)
+from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
+    ListTransferJobsAsyncPager,
+)
 from googleapiclient.discovery import Resource, build
 from googleapiclient.errors import HttpError
+from proto import Message
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 log = logging.getLogger(__name__)
 
@@ -60,6 +79,7 @@ class GcpTransferOperationStatus:
 ACCESS_KEY_ID = "accessKeyId"
 ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
 AWS_ACCESS_KEY = "awsAccessKey"
+AWS_SECRET_ACCESS_KEY = "secretAccessKey"
 AWS_S3_DATA_SOURCE = "awsS3DataSource"
 BODY = "body"
 BUCKET_NAME = "bucketName"
@@ -73,6 +93,7 @@ class GcpTransferOperationStatus:
 GCS_DATA_SOURCE = "gcsDataSource"
 HOURS = "hours"
 HTTP_DATA_SOURCE = "httpDataSource"
+INCLUDE_PREFIXES = "includePrefixes"
 JOB_NAME = "name"
 LIST_URL = "list_url"
 METADATA = "metadata"
@@ -81,6 +102,7 @@ class GcpTransferOperationStatus:
 NAME = "name"
 OBJECT_CONDITIONS = "object_conditions"
 OPERATIONS = "operations"
+OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
 PATH = "path"
 PROJECT_ID = "projectId"
 SCHEDULE = "schedule"
@@ -466,3 +488,50 @@ def operations_contain_expected_statuses(
             f"Expected: {', '.join(expected_statuses_set)}"
         )
     return False
+
+
+class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
+    """Asynchronous hook for Google Storage Transfer Service."""
+
+    def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self._client: StorageTransferServiceAsyncClient | None = None
+
+    def get_conn(self) -> StorageTransferServiceAsyncClient:
+        """
+        Returns an async connection to the Storage Transfer Service.
+
+        :return: Google Storage Transfer asynchronous client.
+        """
+        if not self._client:
+            self._client = StorageTransferServiceAsyncClient()
+        return self._client
+
+    async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
+        """
+        Gets the latest state of the transfer jobs with the given names in Google Storage Transfer Service.
+
+        :param job_names: (Required) List of names of the jobs to be fetched.
+        :return: Object that yields Transfer jobs.
+        """
+        client = self.get_conn()
+        jobs_list_request = ListTransferJobsRequest(
+            filter=json.dumps(dict(project_id=self.project_id, job_names=job_names))
+        )
+        return await client.list_transfer_jobs(request=jobs_list_request)
+
+    async def get_latest_operation(self, job: TransferJob) -> Message | None:
+        """
+        Gets the latest operation of the given TransferJob instance.
+
+        :param job: Transfer job instance.
+        :return: The latest job operation.
+        """
+        latest_operation_name = job.latest_operation_name
+        if latest_operation_name:
+            client = self.get_conn()
+            response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
+            operation = TransferOperation.deserialize(response_operation.metadata.value)
+            return operation
+        return None
diff --git a/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/airflow/providers/google/cloud/transfers/s3_to_gcs.py
index 39a2f2a68f56e..ce6ebd46e9411 100644
--- a/airflow/providers/google/cloud/transfers/s3_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/s3_to_gcs.py
@@ -17,12 +17,38 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    ACCESS_KEY_ID,
+    AWS_ACCESS_KEY,
+    AWS_S3_DATA_SOURCE,
+    AWS_SECRET_ACCESS_KEY,
+    BUCKET_NAME,
+    GCS_DATA_SINK,
+    INCLUDE_PREFIXES,
+    OBJECT_CONDITIONS,
+    OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK,
+    PATH,
+    PROJECT_ID,
+    SCHEDULE,
+    SCHEDULE_END_DATE,
+    SCHEDULE_START_DATE,
+    STATUS,
+    TRANSFER_OPTIONS,
+    TRANSFER_SPEC,
+    CloudDataTransferServiceHook,
+    GcpTransferJobsStatus,
+)
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
+from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
+    CloudStorageTransferServiceCreateJobsTrigger,
+)
 
 try:
     from airflow.providers.amazon.aws.operators.s3 import S3ListOperator
@@ -71,7 +97,7 @@ class S3ToGCSOperator(S3ListOperator):
         where you want to store the files. (templated)
     :param replace: Whether you want to replace existing destination files or not.
-    :param gzip: Option to compress file for upload
+    :param gzip: Option to compress file for upload. Parameter ignored in deferrable mode.
     :param google_impersonation_chain: Optional Google service account to impersonate using
         short-term credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -80,7 +106,9 @@ class S3ToGCSOperator(S3ListOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-
+    :param deferrable: Run the operator in deferrable mode.
+    :param poll_interval: Time in seconds between polls of the transfer job status.
+        The value is considered only when running in deferrable mode. Must be greater than 0.
     **Example**:
@@ -109,6 +137,7 @@ class S3ToGCSOperator(S3ListOperator):
         "google_impersonation_chain",
     )
     ui_color = "#e09411"
+    transfer_job_max_files_number = 1000
 
     def __init__(
         self,
@@ -124,6 +153,8 @@ def __init__(
         replace=False,
         gzip=False,
         google_impersonation_chain: str | Sequence[str] | None = None,
+        deferrable=conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
         **kwargs,
     ):
 
@@ -135,6 +166,10 @@ def __init__(
         self.verify = verify
         self.gzip = gzip
         self.google_impersonation_chain = google_impersonation_chain
+        self.deferrable = deferrable
+        if poll_interval <= 0:
+            raise ValueError("Invalid value for poll_interval. Expected value greater than 0")
+        self.poll_interval = poll_interval
 
     def _check_inputs(self) -> None:
         if self.dest_gcs and not gcs_object_is_directory(self.dest_gcs):
@@ -159,23 +194,13 @@ def execute(self, context: Context):
         if not self.replace:
             s3_objects = self.exclude_existing_objects(s3_objects=s3_objects, gcs_hook=gcs_hook)
 
-        if s3_objects:
-            hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
-            dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
-            for obj in s3_objects:
-                # GCS hook builds its own in-memory file, so we have to create
-                # and pass the path
-                file_object = hook.get_key(obj, self.bucket)
-                with NamedTemporaryFile(mode="wb", delete=True) as file:
-                    file_object.download_fileobj(file)
-                    file.flush()
-                    gcs_file = self.s3_to_gcs_object(s3_object=obj)
-                    gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)
-
-            self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))
-        else:
+        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+        if not s3_objects:
             self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage")
+        elif self.deferrable:
+            self.transfer_files_async(s3_objects, gcs_hook, s3_hook)
+        else:
+            self.transfer_files(s3_objects, gcs_hook, s3_hook)
 
         return s3_objects
 
@@ -221,3 +246,104 @@ def gcs_to_s3_object(self, gcs_object: str) -> str:
         if self.apply_gcs_prefix:
             return self.prefix + s3_object
         return s3_object
+
+    def transfer_files(self, s3_objects: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> None:
+        if s3_objects:
+            dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
+            for obj in s3_objects:
+                # GCS hook builds its own in-memory file, so we have to create
+                # and pass the path
+                file_object = s3_hook.get_key(obj, self.bucket)
+                with NamedTemporaryFile(mode="wb", delete=True) as file:
+                    file_object.download_fileobj(file)
+                    file.flush()
+                    gcs_file = self.s3_to_gcs_object(s3_object=obj)
+                    gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)
+
+            self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))
+
+    def transfer_files_async(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> None:
+        """Submits Google Cloud Storage Transfer Service job to copy files from AWS S3 to GCS."""
+        if not len(files):
+            raise ValueError("List of transferring files cannot be empty")
+        job_names = self.submit_transfer_jobs(files=files, gcs_hook=gcs_hook, s3_hook=s3_hook)
+
+        self.defer(
+            trigger=CloudStorageTransferServiceCreateJobsTrigger(
+                project_id=gcs_hook.project_id,
+                job_names=job_names,
+                poll_interval=self.poll_interval,
+            ),
+            method_name="execute_complete",
+        )
+
+    def submit_transfer_jobs(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> list[str]:
+        now = datetime.utcnow()
+        one_time_schedule = {"day": now.day, "month": now.month, "year": now.year}
"month": now.month, "year": now.year} + + gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs) + config = s3_hook.conn_config + + body: dict[str, Any] = { + PROJECT_ID: gcs_hook.project_id, + STATUS: GcpTransferJobsStatus.ENABLED, + SCHEDULE: { + SCHEDULE_START_DATE: one_time_schedule, + SCHEDULE_END_DATE: one_time_schedule, + }, + TRANSFER_SPEC: { + AWS_S3_DATA_SOURCE: { + BUCKET_NAME: self.bucket, + AWS_ACCESS_KEY: { + ACCESS_KEY_ID: config.aws_access_key_id, + AWS_SECRET_ACCESS_KEY: config.aws_secret_access_key, + }, + }, + OBJECT_CONDITIONS: { + INCLUDE_PREFIXES: [], + }, + GCS_DATA_SINK: {BUCKET_NAME: gcs_bucket, PATH: gcs_prefix}, + TRANSFER_OPTIONS: { + OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK: self.replace, + }, + }, + } + + # max size of the field 'transfer_job.transfer_spec.object_conditions.include_prefixes' is 1000, + # that's why we submit multiple jobs transferring 1000 files each. + # See documentation below + # https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#ObjectConditions + chunk_size = self.transfer_job_max_files_number + job_names = [] + transfer_hook = self.get_transfer_hook() + for i in range(0, len(files), chunk_size): + files_chunk = files[i : i + chunk_size] + body[TRANSFER_SPEC][OBJECT_CONDITIONS][INCLUDE_PREFIXES] = files_chunk + job = transfer_hook.create_transfer_job(body=body) + + s = "s" if len(files_chunk) > 1 else "" + self.log.info(f"Submitted job {job['name']} to transfer {len(files_chunk)} file{s}") + job_names.append(job["name"]) + + if len(files) > chunk_size: + js = "s" if len(job_names) > 1 else "" + fs = "s" if len(files) > 1 else "" + self.log.info(f"Overall submitted {len(job_names)} job{js} to transfer {len(files)} file{fs}") + + return job_names + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info("%s completed with response %s ", self.task_id, event["message"]) + + def get_transfer_hook(self): + return CloudDataTransferServiceHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ) diff --git a/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py new file mode 100644 index 0000000000000..2ea2bed9d837f --- /dev/null +++ b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from google.api_core.exceptions import GoogleAPIError
+from google.cloud.storage_transfer_v1.types import TransferOperation
+
+from airflow import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    CloudDataTransferServiceAsyncHook,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
+    """
+    Trigger that runs on the trigger worker to check the status of Cloud Storage Transfer jobs.
+
+    :param job_names: List of transfer job names.
+    :param project_id: GCP project id.
+    :param poll_interval: Interval in seconds between polls.
+    """
+
+    def __init__(self, job_names: list[str], project_id: str | None = None, poll_interval: int = 10) -> None:
+        super().__init__()
+        self.project_id = project_id
+        self.job_names = job_names
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes CloudStorageTransferServiceCreateJobsTrigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__qualname__}",
+            {
+                "project_id": self.project_id,
+                "job_names": self.job_names,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+        """Gets current data storage transfer jobs and yields a TriggerEvent."""
+        async_hook: CloudDataTransferServiceAsyncHook = self.get_async_hook()
+
+        while True:
+            self.log.info("Attempting to request jobs statuses")
+            jobs_completed_successfully = 0
+            try:
+                jobs_pager = await async_hook.get_jobs(job_names=self.job_names)
+                jobs, awaitable_operations = [], []
+                async for job in jobs_pager:
+                    operation = async_hook.get_latest_operation(job)
+                    jobs.append(job)
+                    awaitable_operations.append(operation)
+
+                operations: list[TransferOperation] = await asyncio.gather(*awaitable_operations)
+
+                for job, operation in zip(jobs, operations):
+                    if operation is None:
+                        yield TriggerEvent(
+                            {
+                                "status": "error",
+                                "message": f"Transfer job {job.name} has no latest operation.",
+                            }
+                        )
+                        return
+                    elif operation.status == TransferOperation.Status.SUCCESS:
+                        jobs_completed_successfully += 1
+                    elif operation.status in (
+                        TransferOperation.Status.FAILED,
+                        TransferOperation.Status.ABORTED,
+                    ):
+                        yield TriggerEvent(
+                            {
+                                "status": "error",
+                                "message": f"Transfer operation {operation.name} failed with status "
+                                f"{TransferOperation.Status(operation.status).name}",
+                            }
+                        )
+                        return
+            except (GoogleAPIError, AirflowException) as ex:
+                yield TriggerEvent(dict(status="error", message=str(ex)))
+                return
+
+            jobs_total = len(self.job_names)
+            self.log.info("Transfer jobs completed: %s of %s", jobs_completed_successfully, jobs_total)
+            if jobs_completed_successfully == jobs_total:
+                s = "s" if jobs_total > 1 else ""
+                job_names = ", ".join(j for j in self.job_names)
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": f"Transfer job{s} {job_names} completed successfully",
+                    }
+                )
+                return
+
+            self.log.info("Sleeping for %s seconds", self.poll_interval)
+            await asyncio.sleep(self.poll_interval)
+
+    def get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
+        return CloudDataTransferServiceAsyncHook(project_id=self.project_id)
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 02168a12c9535..86a4b11d129f0 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -107,6 +107,7 @@ dependencies:
   - google-cloud-spanner>=3.11.1
   - google-cloud-speech>=2.18.0
   - google-cloud-storage>=2.7.0
+  - google-cloud-storage-transfer>=1.4.1
   - google-cloud-tasks>=2.13.0
   - google-cloud-texttospeech>=2.14.1
   - google-cloud-translate>=3.11.0
@@ -848,6 +849,9 @@ triggers:
   - integration-name: Google Cloud Composer
     python-modules:
       - airflow.providers.google.cloud.triggers.cloud_composer
+  - integration-name: Google Cloud Storage Transfer Service
+    python-modules:
+      - airflow.providers.google.cloud.triggers.cloud_storage_transfer_service
   - integration-name: Google Cloud SQL
     python-modules:
       - airflow.providers.google.cloud.triggers.cloud_sql
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_storage_transfer_service.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_storage_transfer_service.rst
index 90976814cd968..9bff27bce9a9c 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_storage_transfer_service.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_storage_transfer_service.rst
@@ -15,6 +15,10 @@
     specific language governing permissions and limitations
     under the License.
 
+.. spelling::
+
+    ListTransferJobsAsyncPager
+    StorageTransferServiceAsyncClient
 
 Google Cloud Transfer Service Operators
diff --git a/docs/apache-airflow-providers-google/operators/transfer/s3_to_gcs.rst b/docs/apache-airflow-providers-google/operators/transfer/s3_to_gcs.rst
index 4cda423cea08f..dc1a3beb362ae 100644
--- a/docs/apache-airflow-providers-google/operators/transfer/s3_to_gcs.rst
+++ b/docs/apache-airflow-providers-google/operators/transfer/s3_to_gcs.rst
@@ -37,6 +37,16 @@ to transfer data from Amazon S3 to Google Cloud Storage.
     :start-after: [START howto_transfer_s3togcs_operator]
     :end-before: [END howto_transfer_s3togcs_operator]
 
+You can also start ``S3ToGCSOperator`` asynchronously by running it in deferrable mode: pass
+``deferrable=True`` to the operator call. Under the hood, the data transfer is delegated to
+Google Cloud Storage Transfer Service. The ``poll_interval`` parameter controls how often
+(in seconds) the transfer job status is polled.
+
+.. exampleinclude:: /../tests/system/providers/google/cloud/gcs/example_s3_to_gcs_async.py
+    :language: python
+    :start-after: [START howto_transfer_s3togcs_operator_async]
+    :end-before: [END howto_transfer_s3togcs_operator_async]
+
 Reference
 ^^^^^^^^^
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index d80c3d084e01f..09c2211b3aaf1 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -425,6 +425,7 @@
       "google-cloud-secret-manager>=2.16.0",
       "google-cloud-spanner>=3.11.1",
       "google-cloud-speech>=2.18.0",
+      "google-cloud-storage-transfer>=1.4.1",
       "google-cloud-storage>=2.7.0",
       "google-cloud-tasks>=2.13.0",
       "google-cloud-texttospeech>=2.14.1",
diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
new file mode 100644
index 0000000000000..a27d233fd19f6
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
@@ -0,0 +1,116 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    CloudDataTransferServiceAsyncHook,
+)
+from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
+
+TEST_PROJECT_ID = "project-id"
+TRANSFER_HOOK_PATH = "airflow.providers.google.cloud.hooks.cloud_storage_transfer_service"
+
+
+@pytest.fixture
+def hook_async():
+    with mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    ):
+        yield CloudDataTransferServiceAsyncHook()
+
+
+class TestCloudDataTransferServiceAsyncHook:
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient")
+    def test_get_conn(self, mock_async_client):
+        expected_value = "Async Hook"
+        mock_async_client.return_value = expected_value
+
+        hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
+        conn_0 = hook.get_conn()
+        assert conn_0 == expected_value
+
+        conn_1 = hook.get_conn()
+        assert conn_1 == expected_value
+        assert id(conn_0) == id(conn_1)
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.ListTransferJobsRequest")
+    async def test_get_jobs(self, mock_list_jobs_request, mock_get_conn):
+        expected_jobs = AsyncMock()
+        mock_get_conn.return_value.list_transfer_jobs.side_effect = AsyncMock(return_value=expected_jobs)
+
+        expected_request = mock.MagicMock()
+        mock_list_jobs_request.return_value = expected_request
+
+        hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
+        job_names = ["Job0", "Job1"]
+        jobs = await hook.get_jobs(job_names=job_names)
+
+        assert jobs == expected_jobs
+        mock_list_jobs_request.assert_called_once_with(
+            filter=json.dumps(dict(project_id=TEST_PROJECT_ID, job_names=job_names))
+        )
+        mock_get_conn.return_value.list_transfer_jobs.assert_called_once_with(request=expected_request)
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.TransferOperation.deserialize")
+    async def test_get_last_operation(self, mock_deserialize, mock_conn, hook_async):
+        latest_operation_name = "Mock operation name"
+        operation_metadata_value = "Mock metadata value"
+
+        get_operation = AsyncMock()
+        get_operation.return_value = mock.MagicMock(metadata=mock.MagicMock(value=operation_metadata_value))
+        mock_conn.return_value.transport.operations_client.get_operation = get_operation
+
+        expected_operation = mock.MagicMock()
+        mock_deserialize.return_value = expected_operation
+
+        operation = await hook_async.get_latest_operation(
+            job=mock.MagicMock(latest_operation_name=latest_operation_name)
+        )
+
+        get_operation.assert_called_once_with(latest_operation_name)
+        mock_deserialize.assert_called_once_with(operation_metadata_value)
+        assert operation == expected_operation
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.TransferOperation.deserialize")
+    async def test_get_last_operation_none(self, mock_deserialize, mock_conn, hook_async):
+        latest_operation_name = None
+        expected_operation = None
+
+        get_operation = mock.MagicMock()
+        mock_conn.return_value.transport.operations_client.get_operation = get_operation
+
+        operation = await hook_async.get_latest_operation(
+            job=mock.MagicMock(latest_operation_name=latest_operation_name)
+        )
+
+        get_operation.assert_not_called()
+        mock_deserialize.assert_not_called()
+        assert operation == expected_operation
diff --git a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
index 480fbf6f75390..e763abf83fd96 100644
--- a/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_s3_to_gcs.py
@@ -18,11 +18,18 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import PropertyMock
 
 import pytest
+import time_machine
 
+from airflow import AirflowException
+from airflow.exceptions import TaskDeferred
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import CloudDataTransferServiceHook
 from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator
+from airflow.utils.timezone import utcnow
 
+PROJECT_ID = "test-project-id"
 TASK_ID = "test-s3-gcs-operator"
 S3_BUCKET = "test-bucket"
 S3_PREFIX = "TEST"
@@ -30,14 +37,23 @@
 GCS_BUCKET = "gcs-bucket"
 GCS_BUCKET_URI = "gs://" + GCS_BUCKET
 GCS_PREFIX = "data/"
-GCS_PATH_PREFIX = GCS_BUCKET_URI + "/" + GCS_PREFIX
+GCS_BUCKET = "gcs-bucket"
+GCS_BLOB = "data/"
+GCS_PATH_PREFIX = f"gs://{GCS_BUCKET}/{GCS_BLOB}"
 MOCK_FILE_1 = "TEST1.csv"
 MOCK_FILE_2 = "TEST2.csv"
 MOCK_FILE_3 = "TEST3.csv"
 MOCK_FILES = [MOCK_FILE_1, MOCK_FILE_2, MOCK_FILE_3]
 AWS_CONN_ID = "aws_default"
+AWS_ACCESS_KEY_ID = "Mock AWS access key id"
+AWS_SECRET_ACCESS_KEY = "Mock AWS secret access key"
 GCS_CONN_ID = "google_cloud_default"
 IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+DEFERRABLE = False
+POLL_INTERVAL = 10
+TRANSFER_JOB_ID_0 = "test-transfer-job-0"
+TRANSFER_JOB_ID_1 = "test-transfer-job-1"
+TRANSFER_JOBS = [TRANSFER_JOB_ID_0, TRANSFER_JOB_ID_1]
 APPLY_GCS_PREFIX = False
 PARAMETRIZED_OBJECT_PATHS = (
     "apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object",
@@ -67,6 +83,8 @@ def test_init(self):
             dest_gcs=GCS_PATH_PREFIX,
             google_impersonation_chain=IMPERSONATION_CHAIN,
             apply_gcs_prefix=APPLY_GCS_PREFIX,
+            deferrable=DEFERRABLE,
+            poll_interval=POLL_INTERVAL,
         )
 
         assert operator.task_id == TASK_ID
@@ -77,6 +95,8 @@ def test_init(self):
         assert operator.dest_gcs == GCS_PATH_PREFIX
         assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
         assert operator.apply_gcs_prefix == APPLY_GCS_PREFIX
+        assert operator.deferrable == DEFERRABLE
+        assert operator.poll_interval == POLL_INTERVAL
 
     @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
     @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@@ -251,3 +271,228 @@ def test_execute_apply_gcs_prefix(
         )
 
         assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files)
+
+
+class TestS3ToGoogleCloudStorageOperatorDeferrable:
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.CloudDataTransferServiceHook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") + @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") + @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") + def test_execute_deferrable(self, mock_gcs_hook, mock_s3_super_hook, mock_s3_hook, mock_transfer_hook): + mock_gcs_hook.return_value.project_id = PROJECT_ID + + mock_list_keys = mock.MagicMock() + mock_list_keys.return_value = MOCK_FILES + mock_s3_super_hook.return_value.list_keys = mock_list_keys + mock_s3_hook.conn_config = mock.MagicMock( + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) + + mock_create_transfer_job = mock.MagicMock() + mock_create_transfer_job.return_value = dict(name=TRANSFER_JOB_ID_0) + mock_transfer_hook.return_value.create_transfer_job = mock_create_transfer_job + + operator = S3ToGCSOperator( + task_id=TASK_ID, + bucket=S3_BUCKET, + prefix=S3_PREFIX, + delimiter=S3_DELIMITER, + gcp_conn_id=GCS_CONN_ID, + dest_gcs=GCS_PATH_PREFIX, + aws_conn_id=AWS_CONN_ID, + replace=True, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exception_info: + operator.execute(None) + + mock_s3_super_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=operator.verify) + mock_list_keys.assert_called_once_with( + bucket_name=S3_BUCKET, prefix=S3_PREFIX, delimiter=S3_DELIMITER, apply_wildcard=False + ) + mock_create_transfer_job.assert_called_once() + assert hasattr(exception_info.value, "trigger") + trigger = exception_info.value.trigger + assert trigger.project_id == PROJECT_ID + assert trigger.job_names == [TRANSFER_JOB_ID_0] + assert trigger.poll_interval == operator.poll_interval + + assert hasattr(exception_info.value, "method_name") + assert exception_info.value.method_name == "execute_complete" + + @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") + @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") + def test_transfer_files_async( + self, + mock_s3_hook, + mock_gcs_hook, + ): + mock_s3_hook.conn_config = mock.MagicMock( + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) + mock_gcs_hook.project_id = PROJECT_ID + expected_job_names = [TRANSFER_JOB_ID_0] + expected_method_name = "execute_complete" + + operator = S3ToGCSOperator( + task_id=TASK_ID, + bucket=S3_BUCKET, + prefix=S3_PREFIX, + delimiter=S3_DELIMITER, + gcp_conn_id=GCS_CONN_ID, + dest_gcs=GCS_PATH_PREFIX, + ) + + with mock.patch.object(operator, "submit_transfer_jobs") as mock_submit_transfer_jobs: + mock_submit_transfer_jobs.return_value = expected_job_names + with pytest.raises(TaskDeferred) as exception_info: + operator.transfer_files_async(files=MOCK_FILES, gcs_hook=mock_gcs_hook, s3_hook=mock_s3_hook) + + mock_submit_transfer_jobs.assert_called_once_with( + files=MOCK_FILES, gcs_hook=mock_gcs_hook, s3_hook=mock_s3_hook + ) + + assert hasattr(exception_info.value, "trigger") + trigger = exception_info.value.trigger + assert trigger.project_id == PROJECT_ID + assert trigger.job_names == expected_job_names + assert trigger.poll_interval == operator.poll_interval + + assert hasattr(exception_info.value, "method_name") + assert exception_info.value.method_name == expected_method_name + + @pytest.mark.parametrize("invalid_poll_interval", [-5, 0]) + def test_init_error_polling_interval(self, invalid_poll_interval): + operator = None + expected_error_message = "Invalid value for poll_interval. 
+        with pytest.raises(ValueError, match=expected_error_message):
+            operator = S3ToGCSOperator(
+                task_id=TASK_ID,
+                bucket=S3_BUCKET,
+                prefix=S3_PREFIX,
+                delimiter=S3_DELIMITER,
+                gcp_conn_id=GCS_CONN_ID,
+                dest_gcs=GCS_PATH_PREFIX,
+                poll_interval=invalid_poll_interval,
+            )
+        assert operator is None
+
+    def test_transfer_files_async_error_no_files(self):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=S3_PREFIX,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+        )
+        expected_error_message = "List of transferring files cannot be empty"
+        with pytest.raises(ValueError, match=expected_error_message):
+            operator.transfer_files_async(files=[], gcs_hook=mock.MagicMock(), s3_hook=mock.MagicMock())
+
+    @pytest.mark.parametrize(
+        "file_names, chunks, expected_job_names",
+        [
+            (MOCK_FILES, [MOCK_FILES], [TRANSFER_JOB_ID_0]),
+            (
+                [f"path/to/file{i}" for i in range(2000)],
+                [
+                    [f"path/to/file{i}" for i in range(1000)],
+                    [f"path/to/file{i}" for i in range(1000, 2000)],
+                ],
+                TRANSFER_JOBS,
+            ),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
+    def test_submit_transfer_jobs(
+        self,
+        mock_s3_hook,
+        mock_gcs_hook,
+        file_names,
+        chunks,
+        expected_job_names,
+    ):
+        mock_s3_hook.conn_config = mock.MagicMock(
+            aws_access_key_id=AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+        )
+        mock_gcs_hook.project_id = PROJECT_ID
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=S3_PREFIX,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+        )
+
+        now_time = utcnow()
+        with time_machine.travel(now_time):
+            with mock.patch.object(operator, "get_transfer_hook") as mock_get_transfer_hook:
+                mock_create_transfer_job = mock.MagicMock(
+                    side_effect=[dict(name=job_name) for job_name in expected_job_names]
+                )
+                mock_get_transfer_hook.return_value = mock.MagicMock(
+                    create_transfer_job=mock_create_transfer_job
+                )
+                job_names = operator.submit_transfer_jobs(
+                    files=file_names,
+                    gcs_hook=mock_gcs_hook,
+                    s3_hook=mock_s3_hook,
+                )
+
+        mock_get_transfer_hook.assert_called_once()
+        mock_create_transfer_job.assert_called()
+        assert job_names == expected_job_names
+
+    @mock.patch(
+        "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
+    )
+    def test_execute_complete_success(self, mock_log):
+        expected_event_message = "Event message (success)"
+        event = {
+            "status": "success",
+            "message": expected_event_message,
+        }
+        operator = S3ToGCSOperator(task_id=TASK_ID, bucket=S3_BUCKET)
+        operator.execute_complete(context=mock.MagicMock(), event=event)
+
+        mock_log.return_value.info.assert_called_once_with(
+            "%s completed with response %s ", TASK_ID, event["message"]
+        )
+
+    @mock.patch(
+        "airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator.log", new_callable=PropertyMock
+    )
+    def test_execute_complete_error(self, mock_log):
+        expected_event_message = "Event error message"
+        event = {
+            "status": "error",
+            "message": expected_event_message,
+        }
+        operator = S3ToGCSOperator(task_id=TASK_ID, bucket=S3_BUCKET)
+        with pytest.raises(AirflowException, match=expected_event_message):
+            operator.execute_complete(context=mock.MagicMock(), event=event)
+
+        mock_log.return_value.info.assert_not_called()
+
+    def test_get_transfer_hook(self):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=S3_PREFIX,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+            google_impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        transfer_hook = operator.get_transfer_hook()
+
+        assert isinstance(transfer_hook, CloudDataTransferServiceHook)
+        assert transfer_hook.gcp_conn_id == GCS_CONN_ID
+        assert transfer_hook.impersonation_chain == IMPERSONATION_CHAIN
diff --git a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
new file mode 100644
index 0000000000000..a7108d69380fa
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
@@ -0,0 +1,317 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+from google.api_core.exceptions import GoogleAPICallError
+from google.cloud.storage_transfer_v1 import TransferOperation
+
+from airflow import AirflowException
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+    from asynctest import mock
+else:
+    from unittest import mock
+
+import pytest
+
+from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
+    CloudDataTransferServiceAsyncHook,
+)
+from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
+    CloudStorageTransferServiceCreateJobsTrigger,
+)
+
+PROJECT_ID = "test-project"
+JOB_0 = "test-job-0"
+JOB_1 = "test-job-1"
+JOB_NAMES = [JOB_0, JOB_1]
+LATEST_OPERATION_NAME_0 = "test-latest-operation-0"
+LATEST_OPERATION_NAME_1 = "test-latest-operation-1"
+LATEST_OPERATION_NAMES = [LATEST_OPERATION_NAME_0, LATEST_OPERATION_NAME_1]
+POLL_INTERVAL = 2
+CLASS_PATH = (
+    "airflow.providers.google.cloud.triggers.cloud_storage_transfer_service"
+    ".CloudStorageTransferServiceCreateJobsTrigger"
+)
+ASYNC_HOOK_CLASS_PATH = (
+    "airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceAsyncHook"
+)
+PYTHON_VERSION = (sys.version_info.major, sys.version_info.minor, sys.version_info.micro)
+
+
+@pytest.fixture(scope="session")
+def trigger():
+    return CloudStorageTransferServiceCreateJobsTrigger(
+        project_id=PROJECT_ID, job_names=JOB_NAMES, poll_interval=POLL_INTERVAL
+    )
+
+
+def mock_jobs(names: list[str], latest_operation_names: list[str | None]):
+    """Returns object that mocks asynchronous looping over mock jobs"""
+    jobs = [mock.MagicMock(latest_operation_name=name) for name in latest_operation_names]
+    for job, name in zip(jobs, names):
+        job.name = name
+    mock_obj = mock.MagicMock()
+    mock_obj.__aiter__.return_value = (job for job in jobs)
+    return mock_obj
+
+
+def create_mock_operation(status: TransferOperation.Status, name: str) -> mock.MagicMock:
+    _obj = mock.MagicMock(status=status)
+    _obj.name = name
+    return _obj
+
+
+class TestCloudStorageTransferServiceCreateJobsTrigger:
+    def test_serialize(self, trigger):
+        class_path, serialized = trigger.serialize()
+
+        assert class_path == CLASS_PATH
+        assert serialized == {
+            "project_id": PROJECT_ID,
+            "job_names": JOB_NAMES,
+            "poll_interval": POLL_INTERVAL,
+        }
+
+    def test_get_async_hook(self, trigger):
+        hook = trigger.get_async_hook()
+
+        assert isinstance(hook, CloudDataTransferServiceAsyncHook)
+        assert hook.project_id == PROJECT_ID
+
+    @pytest.mark.asyncio
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation")
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs")
+    async def test_run(self, get_jobs, get_latest_operation, trigger):
+        get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES)
+        get_latest_operation.side_effect = [
+            create_mock_operation(status=TransferOperation.Status.SUCCESS, name="operation_" + job_name)
+            for job_name in JOB_NAMES
+        ]
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Transfer jobs {JOB_0}, {JOB_1} completed successfully",
+            }
+        )
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert actual_event == expected_event
+
+    @pytest.mark.skipif(
+        (3, 7, 7) < PYTHON_VERSION < (3, 8, 0),
+        reason="Unresolved issue in asynctest: https://github.com/Martiusweb/asynctest/issues/152",
+    )
+    @pytest.mark.parametrize(
+        "status",
+        [
+            TransferOperation.Status.STATUS_UNSPECIFIED,
+            TransferOperation.Status.IN_PROGRESS,
+            TransferOperation.Status.PAUSED,
+            TransferOperation.Status.QUEUED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation", autospec=True)
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs", autospec=True)
+    async def test_run_poll_interval(self, get_jobs, get_latest_operation, mock_sleep, trigger, status):
+        get_jobs.side_effect = [
+            mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES),
+            mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES),
+        ]
+        get_latest_operation.side_effect = [
+            create_mock_operation(status=status, name="operation_" + job_name) for job_name in JOB_NAMES
+        ] + [
+            create_mock_operation(status=TransferOperation.Status.SUCCESS, name="operation_" + job_name)
+            for job_name in JOB_NAMES
+        ]
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Transfer jobs {JOB_0}, {JOB_1} completed successfully",
+            }
+        )
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert actual_event == expected_event
+        mock_sleep.assert_called_once_with(POLL_INTERVAL)
+
+    @pytest.mark.skipif(
+        (3, 7, 7) < PYTHON_VERSION < (3, 8, 0),
+        reason="Unresolved issue in asynctest: https://github.com/Martiusweb/asynctest/issues/152",
+    )
+    @pytest.mark.parametrize(
+        "latest_operations_names, expected_failed_job",
+        [
+            ([None, LATEST_OPERATION_NAME_1], JOB_0),
+            ([LATEST_OPERATION_NAME_0, None], JOB_1),
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation")
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs")
+    async def test_run_error_job_has_no_latest_operation(
+        self, get_jobs, get_latest_operation, trigger, latest_operations_names, expected_failed_job
+    ):
+        get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=latest_operations_names)
+        get_latest_operation.side_effect = [
+            create_mock_operation(status=TransferOperation.Status.SUCCESS, name="operation_" + job_name)
name="operation_" + job_name) + if job_name + else None + for job_name in latest_operations_names + ] + expected_event = TriggerEvent( + { + "status": "error", + "message": f"Transfer job {expected_failed_job} has no latest operation.", + } + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == expected_event + + @pytest.mark.skipif( + (3, 7, 7) < PYTHON_VERSION < (3, 8, 0), + reason="Unresolved issue in asynctest: https://github.com/Martiusweb/asynctest/issues/152", + ) + @pytest.mark.parametrize( + "job_statuses, failed_operation, expected_status", + [ + ( + [TransferOperation.Status.ABORTED, TransferOperation.Status.SUCCESS], + LATEST_OPERATION_NAME_0, + "ABORTED", + ), + ( + [TransferOperation.Status.FAILED, TransferOperation.Status.SUCCESS], + LATEST_OPERATION_NAME_0, + "FAILED", + ), + ( + [TransferOperation.Status.SUCCESS, TransferOperation.Status.ABORTED], + LATEST_OPERATION_NAME_1, + "ABORTED", + ), + ( + [TransferOperation.Status.SUCCESS, TransferOperation.Status.FAILED], + LATEST_OPERATION_NAME_1, + "FAILED", + ), + ], + ) + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation") + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs") + async def test_run_error_one_job_failed_or_aborted( + self, + get_jobs, + get_latest_operation, + trigger, + job_statuses, + failed_operation, + expected_status, + ): + get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES) + get_latest_operation.side_effect = [ + create_mock_operation(status=status, name=operation_name) + for status, operation_name in zip(job_statuses, LATEST_OPERATION_NAMES) + ] + expected_event = TriggerEvent( + { + "status": "error", + "message": f"Transfer operation {failed_operation} failed with status {expected_status}", + } + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation") + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs") + async def test_run_get_jobs_airflow_exception(self, get_jobs, get_latest_operation, trigger): + expected_error_message = "Mock error message" + get_jobs.side_effect = AirflowException(expected_error_message) + + get_latest_operation.side_effect = [ + create_mock_operation(status=TransferOperation.Status.SUCCESS, name="operation_" + job_name) + for job_name in JOB_NAMES + ] + expected_event = TriggerEvent( + { + "status": "error", + "message": expected_error_message, + } + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation") + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs") + async def test_run_get_latest_operation_airflow_exception(self, get_jobs, get_latest_operation, trigger): + get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES) + expected_error_message = "Mock error message" + get_latest_operation.side_effect = AirflowException(expected_error_message) + + expected_event = TriggerEvent( + { + "status": "error", + "message": expected_error_message, + } + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_latest_operation") + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".get_jobs") + async def 
+        self, get_jobs, get_latest_operation, trigger
+    ):
+        get_jobs.return_value = mock_jobs(names=JOB_NAMES, latest_operation_names=LATEST_OPERATION_NAMES)
+        error_message = "Mock error message"
+        get_latest_operation.side_effect = GoogleAPICallError(error_message)
+
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"{None} {error_message}",
+            }
+        )
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        assert actual_event == expected_event
diff --git a/tests/system/providers/google/cloud/gcs/example_s3_to_gcs_async.py b/tests/system/providers/google/cloud/gcs/example_s3_to_gcs_async.py
new file mode 100644
index 0000000000000..ccecc5db4349c
--- /dev/null
+++ b/tests/system/providers/google/cloud/gcs/example_s3_to_gcs_async.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.decorators import task
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
+from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "example_s3_to_gcs"
+
+BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
+GCS_BUCKET_URL = f"gs://{BUCKET_NAME}/"
+UPLOAD_FILE = "/tmp/example-file.txt"
+PREFIX = "TESTS"
+
+
+@task(task_id="upload_file_to_s3")
+def upload_file():
+    """A callable to upload file to AWS bucket"""
+    s3_hook = S3Hook()
+    s3_hook.load_file(filename=UPLOAD_FILE, key=PREFIX, bucket_name=BUCKET_NAME)
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "s3"],
+) as dag:
+    create_s3_bucket = S3CreateBucketOperator(
+        task_id="create_s3_bucket", bucket_name=BUCKET_NAME, region_name="us-east-1"
+    )
+
+    create_gcs_bucket = GCSCreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=BUCKET_NAME,
+        project_id=GCP_PROJECT_ID,
+    )
+
+    # [START howto_transfer_s3togcs_operator_async]
+    transfer_to_gcs = S3ToGCSOperator(
+        task_id="s3_to_gcs_task", bucket=BUCKET_NAME, prefix=PREFIX, dest_gcs=GCS_BUCKET_URL, deferrable=True
+    )
+    # [END howto_transfer_s3togcs_operator_async]
+
+    delete_s3_bucket = S3DeleteBucketOperator(
+        task_id="delete_s3_bucket",
+        bucket_name=BUCKET_NAME,
+        force_delete=True,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_gcs_bucket = GCSDeleteBucketOperator(
task_id="delete_gcs_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + # TEST SETUP + create_gcs_bucket + >> create_s3_bucket + >> upload_file() + # TEST BODY + >> transfer_to_gcs + # TEST TEARDOWN + >> delete_s3_bucket + >> delete_gcs_bucket + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)