diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index 42abfc1c24639..304a6f88fa3d2 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -39,16 +39,18 @@ from urllib.parse import quote_plus import httpx +from aiohttp import ClientSession +from gcloud.aio.auth import AioSession, Token from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError - -from airflow.exceptions import AirflowException +from requests import Session # Number of retries - used by googleapiclient method calls to perform retries # For requests that are "retriable" +from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import Connection -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.log.logging_mixin import LoggingMixin @@ -300,8 +302,7 @@ def delete_database(self, instance: str, database: str, project_id: str) -> None self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) @GoogleBaseHook.fallback_to_default_project_id - @GoogleBaseHook.operation_in_progress_retry() - def export_instance(self, instance: str, body: dict, project_id: str) -> None: + def export_instance(self, instance: str, body: dict, project_id: str): """ Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump or CSV file. @@ -321,7 +322,7 @@ def export_instance(self, instance: str, body: dict, project_id: str) -> None: .execute(num_retries=self.num_retries) ) operation_name = response["name"] - self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name) + return operation_name @GoogleBaseHook.fallback_to_default_project_id def import_instance(self, instance: str, body: dict, project_id: str) -> None: @@ -376,6 +377,7 @@ def clone_instance(self, instance: str, body: dict, project_id: str) -> None: except HttpError as ex: raise AirflowException(f"Cloning of instance {instance} failed: {ex.content}") + @GoogleBaseHook.fallback_to_default_project_id def _wait_for_operation_to_complete( self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS ) -> None: @@ -412,6 +414,42 @@ def _wait_for_operation_to_complete( ) +class CloudSQLAsyncHook(GoogleBaseAsyncHook): + """Class to get asynchronous hook for Google Cloud SQL.""" + + sync_hook_class = CloudSQLHook + + async def _get_conn(self, session: Session, url: str): + scopes = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/sqlservice.admin", + ] + + async with Token(scopes=scopes) as token: + session_aio = AioSession(session) + headers = { + "Authorization": f"Bearer {await token.get()}", + } + return await session_aio.get(url=url, headers=headers) + + async def get_operation_name(self, project_id: str, operation_name: str, session): + url = f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}" + return await self._get_conn(url=str(url), session=session) + + async def get_operation(self, project_id: str, operation_name: str): + async with ClientSession() as session: + try: + operation = await 
self.get_operation_name( + project_id=project_id, + operation_name=operation_name, + session=session, + ) + operation = await operation.json(content_type=None) + except HttpError as e: + raise e + return operation + + class CloudSqlProxyRunner(LoggingMixin): """ Downloads and runs cloud-sql-proxy as subprocess of the Python process. diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py index 20a254b954ffe..5c77cbd86c948 100644 --- a/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/airflow/providers/google/cloud/operators/cloud_sql.py @@ -28,6 +28,7 @@ from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator from airflow.providers.google.common.hooks.base_google import get_field from airflow.providers.google.common.links.storage import FileDetailsLink @@ -926,6 +927,9 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Run operator in the deferrable mode. + :param poke_interval: (Deferrable mode only) Time (seconds) to wait between calls + to check the run status. """ # [START gcp_sql_export_template_fields] @@ -951,10 +955,14 @@ def __init__( api_version: str = "v1beta4", validate_body: bool = True, impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = False, + poke_interval: int = 10, **kwargs, ) -> None: self.body = body self.validate_body = validate_body + self.deferrable = deferrable + self.poke_interval = poke_interval super().__init__( project_id=project_id, instance=instance, @@ -994,7 +1002,38 @@ def execute(self, context: Context) -> None: uri=self.body["exportContext"]["uri"][5:], project_id=self.project_id or hook.project_id, ) - return hook.export_instance(project_id=self.project_id, instance=self.instance, body=self.body) + + operation_name = hook.export_instance( + project_id=self.project_id, instance=self.instance, body=self.body + ) + + if not self.deferrable: + return hook._wait_for_operation_to_complete( + project_id=self.project_id, operation_name=operation_name + ) + else: + self.defer( + trigger=CloudSQLExportTrigger( + operation_name=operation_name, + project_id=self.project_id or hook.project_id, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context, event=None) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. 
+ """ + if event["status"] == "success": + self.log.info("Operation %s completed successfully", event["operation_name"]) + else: + self.log.exception("Unexpected error in the operation.") + raise AirflowException(event["message"]) class CloudSQLImportInstanceOperator(CloudSQLBaseOperator): diff --git a/airflow/providers/google/cloud/triggers/cloud_sql.py b/airflow/providers/google/cloud/triggers/cloud_sql.py new file mode 100644 index 0000000000000..7d2cd5a323c64 --- /dev/null +++ b/airflow/providers/google/cloud/triggers/cloud_sql.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud SQL triggers.""" +from __future__ import annotations + +import asyncio +from typing import Sequence + +from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook, CloudSqlOperationStatus +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class CloudSQLExportTrigger(BaseTrigger): + """ + Trigger that periodically polls information from Cloud SQL API to verify job status. + Implementation leverages asynchronous transport. 
+    """
+
+    def __init__(
+        self,
+        operation_name: str,
+        project_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poke_interval: int = 20,
+    ):
+        super().__init__()
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.operation_name = operation_name
+        self.project_id = project_id
+        self.poke_interval = poke_interval
+        self.hook = CloudSQLAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    def serialize(self):
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger",
+            {
+                "operation_name": self.operation_name,
+                "project_id": self.project_id,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self):
+        while True:
+            try:
+                operation = await self.hook.get_operation(
+                    project_id=self.project_id, operation_name=self.operation_name
+                )
+                if operation["status"] == CloudSqlOperationStatus.DONE:
+                    if "error" in operation:
+                        yield TriggerEvent(
+                            {
+                                "operation_name": operation["name"],
+                                "status": "error",
+                                "message": operation["error"]["message"],
+                            }
+                        )
+                        return
+                    yield TriggerEvent(
+                        {
+                            "operation_name": operation["name"],
+                            "status": "success",
+                        }
+                    )
+                    return
+                else:
+                    self.log.info(
+                        "Operation status is %s, sleeping for %s seconds.",
+                        operation["status"],
+                        self.poke_interval,
+                    )
+                    await asyncio.sleep(self.poke_interval)
+            except Exception as e:
+                self.log.exception("Exception occurred while checking operation status.")
+                yield TriggerEvent(
+                    {
+                        "status": "failed",
+                        "message": str(e),
+                    }
+                )
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 64fcba31bcd05..3fc2c90b9b529 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -847,6 +847,9 @@ triggers:
   - integration-name: Google Cloud Composer
     python-modules:
       - airflow.providers.google.cloud.triggers.cloud_composer
+  - integration-name: Google Cloud SQL
+    python-modules:
+      - airflow.providers.google.cloud.triggers.cloud_sql
   - integration-name: Google Dataflow
     python-modules:
       - airflow.providers.google.cloud.triggers.dataflow
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
index 80076b95f5f1b..4f7c9428c359f 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
@@ -241,6 +241,14 @@ it will be retrieved from the Google Cloud connection used. Both variants are sh
     :start-after: [START howto_operator_cloudsql_export]
     :end-before: [END howto_operator_cloudsql_export]
 
+You can also run the export operation in deferrable mode:
+
+..
exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_cloudsql_export_async] + :end-before: [END howto_operator_cloudsql_export_async] + Templating """""""""" diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py index 27eb176da4018..0d90f8d0c9d1f 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_sql.py +++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py @@ -24,13 +24,17 @@ from unittest import mock from unittest.mock import PropertyMock +import aiohttp import httplib2 import pytest +from aiohttp.helpers import TimerNoop from googleapiclient.errors import HttpError +from yarl import URL from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.google.cloud.hooks.cloud_sql import ( + CloudSQLAsyncHook, CloudSQLDatabaseHook, CloudSQLHook, CloudSqlProxyRunner, @@ -40,6 +44,26 @@ mock_base_gcp_hook_no_default_project_id, ) +HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}" +PROJECT_ID = "test_project_id" +OPERATION_NAME = "test_operation_name" +OPERATION_URL = ( + f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}" +) + + +@pytest.fixture +def hook_async(): + with mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__", + new=mock_base_gcp_hook_default_project_id, + ): + yield CloudSQLAsyncHook() + + +def session(): + return mock.Mock() + class TestGcpSqlHookDefaultProjectId: def test_delegate_to_runtime_error(self): @@ -116,9 +140,6 @@ def test_instance_export(self, wait_for_operation_to_complete, get_conn, mock_ge export_method.assert_called_once_with(body={}, instance="instance", project="example-project") execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with( - project_id="example-project", operation_name="operation_id" - ) assert 1 == mock_get_credentials.call_count @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn") @@ -133,14 +154,9 @@ def test_instance_export_with_in_progress_retry(self, wait_for_operation_to_comp ), {"name": "operation_id"}, ] - wait_for_operation_to_complete.return_value = None - self.cloudsql_hook.export_instance(project_id="example-project", instance="instance", body={}) - - assert 2 == export_method.call_count - assert 2 == execute_method.call_count - wait_for_operation_to_complete.assert_called_once_with( - project_id="example-project", operation_name="operation_id" - ) + with pytest.raises(HttpError): + self.cloudsql_hook.export_instance(project_id="example-project", instance="instance", body={}) + wait_for_operation_to_complete.assert_not_called() @mock.patch( "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_credentials_and_project_id", @@ -551,9 +567,6 @@ def test_instance_export_overridden_project_id( ) export_method.assert_called_once_with(body={}, instance="instance", project="example-project") execute_method.assert_called_once_with(num_retries=5) - wait_for_operation_to_complete.assert_called_once_with( - project_id="example-project", operation_name="operation_id" - ) @mock.patch( "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", @@ -1238,3 +1251,75 @@ def test_cloud_sql_proxy_runner_version_nok(self, version): ) with pytest.raises(ValueError, match="The sql_proxy_version should 
match the regular expression"): runner._get_sql_proxy_download_url() + + +class TestCloudSQLAsyncHook: + @pytest.mark.asyncio + @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn")) + async def test_async_get_operation_name_should_execute_successfully(self, mocked_conn, hook_async): + await hook_async.get_operation_name( + operation_name=OPERATION_NAME, + project_id=PROJECT_ID, + session=session, + ) + + mocked_conn.assert_awaited_once_with(url=OPERATION_URL, session=session) + + @pytest.mark.asyncio + @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name")) + async def test_async_get_operation_completed_should_execute_successfully(self, mocked_get, hook_async): + response = aiohttp.ClientResponse( + "get", + URL(OPERATION_URL), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=session, + ) + response.status = 200 + mocked_get.return_value = response + mocked_get.return_value._headers = {"Authorization": "test-token"} + mocked_get.return_value._body = b'{"status": "DONE"}' + + operation = await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID) + mocked_get.assert_awaited_once() + assert operation["status"] == "DONE" + + @pytest.mark.asyncio + @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name")) + async def test_async_get_operation_running_should_execute_successfully(self, mocked_get, hook_async): + response = aiohttp.ClientResponse( + "get", + URL(OPERATION_URL), + request_info=mock.Mock(), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=session, + ) + response.status = 200 + mocked_get.return_value = response + mocked_get.return_value._headers = {"Authorization": "test-token"} + mocked_get.return_value._body = b'{"status": "RUNNING"}' + + operation = await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID) + mocked_get.assert_awaited_once() + assert operation["status"] == "RUNNING" + + @pytest.mark.asyncio + @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn")) + async def test_async_get_operation_exception_should_execute_successfully( + self, mocked_get_conn, hook_async + ): + """Assets that the logging is done correctly when CloudSQLAsyncHook raises HttpError""" + + mocked_get_conn.side_effect = HttpError( + resp=mock.MagicMock(status=409), content=b"Operation already exists" + ) + with pytest.raises(HttpError): + await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID) diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py index 888dc11ebba1f..903c5e3c41c98 100644 --- a/tests/providers/google/cloud/operators/test_cloud_sql.py +++ b/tests/providers/google/cloud/operators/test_cloud_sql.py @@ -22,7 +22,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection from airflow.providers.google.cloud.operators.cloud_sql import ( CloudSQLCloneInstanceOperator, @@ -36,6 +36,8 @@ CloudSQLInstancePatchOperator, CloudSQLPatchInstanceDatabaseOperator, ) +from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger +from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME PROJECT_ID = os.environ.get("PROJECT_ID", "project-id") INSTANCE_NAME = os.environ.get("INSTANCE_NAME", "test-name") @@ -669,6 
+671,39 @@ def test_instance_export_missing_project_id(self, mock_hook): ) assert result + @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") + @mock.patch("airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLAsyncHook") + def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook): + operator = CloudSQLExportInstanceOperator( + task_id="test_task", + instance=INSTANCE_NAME, + body=EXPORT_BODY, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + operator.execute(mock.MagicMock()) + + mock_hook.return_value.export_instance.assert_called_once() + + mock_hook.return_value.get_operation.assert_not_called() + assert isinstance(exc.value.trigger, CloudSQLExportTrigger) + assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME + + def test_async_execute_should_should_throw_exception(self): + """Tests that an AirflowException is raised in case of error event""" + + op = CloudSQLExportInstanceOperator( + task_id="test_task", + instance=INSTANCE_NAME, + body=EXPORT_BODY, + deferrable=True, + ) + with pytest.raises(AirflowException): + op.execute_complete( + context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} + ) + @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook") def test_instance_import(self, mock_hook): mock_hook.return_value.export_instance.return_value = True diff --git a/tests/providers/google/cloud/triggers/test_cloud_sql.py b/tests/providers/google/cloud/triggers/test_cloud_sql.py new file mode 100644 index 0000000000000..c7cbb2046ceaf --- /dev/null +++ b/tests/providers/google/cloud/triggers/test_cloud_sql.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import logging +from unittest import mock as async_mock + +import pytest + +from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger +from airflow.triggers.base import TriggerEvent + +CLASSPATH = "airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger" +TASK_ID = "test_task" +TEST_POLL_INTERVAL = 10 +TEST_GCP_CONN_ID = "test-project" +HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}" +PROJECT_ID = "test_project_id" +OPERATION_NAME = "test_operation_name" +OPERATION_URL = ( + f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}" +) + + +@pytest.fixture +def trigger(): + return CloudSQLExportTrigger( + operation_name=OPERATION_NAME, + project_id=PROJECT_ID, + impersonation_chain=None, + gcp_conn_id=TEST_GCP_CONN_ID, + poke_interval=TEST_POLL_INTERVAL, + ) + + +class TestCloudSQLExportTrigger: + def test_async_export_trigger_serialization_should_execute_successfully(self, trigger): + """ + Asserts that the CloudSQLExportTrigger correctly serializes its arguments + and classpath. + """ + classpath, kwargs = trigger.serialize() + assert classpath == CLASSPATH + assert kwargs == { + "operation_name": OPERATION_NAME, + "project_id": PROJECT_ID, + "impersonation_chain": None, + "gcp_conn_id": TEST_GCP_CONN_ID, + "poke_interval": TEST_POLL_INTERVAL, + } + + @pytest.mark.asyncio + @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation")) + async def test_async_export_trigger_on_success_should_execute_successfully( + self, mock_get_operation, trigger + ): + """ + Tests the CloudSQLExportTrigger only fires once the job execution reaches a successful state. + """ + mock_get_operation.return_value = { + "status": "DONE", + "name": OPERATION_NAME, + } + generator = trigger.run() + actual = await generator.asend(None) + assert ( + TriggerEvent( + { + "operation_name": OPERATION_NAME, + "status": "success", + } + ) + == actual + ) + + @pytest.mark.asyncio + @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation")) + async def test_async_export_trigger_running_should_execute_successfully( + self, mock_get_operation, trigger, caplog + ): + """ + Test that CloudSQLExportTrigger does not fire while a job is still running. + """ + + mock_get_operation.return_value = { + "status": "RUNNING", + "name": OPERATION_NAME, + } + caplog.set_level(logging.INFO) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + + assert f"Operation status is RUNNING, sleeping for {TEST_POLL_INTERVAL} seconds." in caplog.text + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation")) + async def test_async_export_trigger_error_should_execute_successfully(self, mock_get_operation, trigger): + """ + Test that CloudSQLExportTrigger fires the correct event in case of an error. 
+        """
+        mock_get_operation.return_value = {
+            "status": "DONE",
+            "name": OPERATION_NAME,
+            "error": {"message": "test_error"},
+        }
+
+        expected_event = {
+            "operation_name": OPERATION_NAME,
+            "status": "error",
+            "message": "test_error",
+        }
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent(expected_event) == actual
+
+    @pytest.mark.asyncio
+    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_exception_should_execute_successfully(
+        self, mock_get_operation, trigger
+    ):
+        """
+        Test that CloudSQLExportTrigger fires the correct event when polling raises an exception.
+        """
+        mock_get_operation.side_effect = Exception("Test exception")
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent({"status": "failed", "message": "Test exception"}) == actual
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
new file mode 100644
index 0000000000000..7859d938702d1
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that creates a Cloud SQL instance and a database inside it, exports the
+database to a Cloud Storage bucket with CloudSQLExportInstanceOperator in deferrable mode,
+and then deletes the database, the instance and the bucket, in Google Cloud.
+
+This DAG relies on the following OS environment variables
+https://airflow.apache.org/concepts.html#variables
+* SYSTEM_TESTS_GCP_PROJECT - Google Cloud project for the Cloud SQL instance.
+* SYSTEM_TESTS_ENV_ID - Unique environment ID used to build the instance, database and bucket names.
+""" +from __future__ import annotations + +import os +from datetime import datetime +from urllib.parse import urlsplit + +from airflow import models +from airflow.models.xcom_arg import XComArg +from airflow.providers.google.cloud.operators.cloud_sql import ( + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceDatabaseOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExportInstanceOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSBucketCreateAclEntryOperator, + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") +DAG_ID = "cloudsql-def" + +INSTANCE_NAME = f"{DAG_ID}-{ENV_ID}-instance" +DB_NAME = f"{DAG_ID}-{ENV_ID}-db" + +BUCKET_NAME = f"{DAG_ID}_{ENV_ID}_bucket" +FILE_NAME = f"{DAG_ID}_{ENV_ID}_exportImportTestFile" +FILE_URI = f"gs://{BUCKET_NAME}/{FILE_NAME}" + +# Bodies below represent Cloud SQL instance resources: +# https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances + +body = { + "name": INSTANCE_NAME, + "settings": { + "tier": "db-n1-standard-1", + "backupConfiguration": {"binaryLogEnabled": True, "enabled": True, "startTime": "05:00"}, + "activationPolicy": "ALWAYS", + "dataDiskSizeGb": 30, + "dataDiskType": "PD_SSD", + "databaseFlags": [], + "ipConfiguration": { + "ipv4Enabled": True, + "requireSsl": True, + }, + "locationPreference": {"zone": "europe-west4-a"}, + "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"}, + "pricingPlan": "PER_USE", + "replicationType": "ASYNCHRONOUS", + "storageAutoResize": True, + "storageAutoResizeLimit": 0, + "userLabels": {"my-key": "my-value"}, + }, + "databaseVersion": "MYSQL_5_7", + "region": "europe-west4", +} + +export_body = { + "exportContext": { + "fileType": "sql", + "uri": FILE_URI, + "sqlExportOptions": {"schemaOnly": False}, + "offload": True, + } +} + +db_create_body = {"instance": INSTANCE_NAME, "name": DB_NAME, "project": PROJECT_ID} + + +with models.DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "cloud_sql"], +) as dag: + create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME) + + sql_instance_create_task = CloudSQLCreateInstanceOperator( + body=body, instance=INSTANCE_NAME, task_id="sql_instance_create_task" + ) + + sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator( + body=db_create_body, instance=INSTANCE_NAME, task_id="sql_db_create_task" + ) + + file_url_split = urlsplit(FILE_URI) + + # For export & import to work we need to add the Cloud SQL instance's Service Account + # write access to the destination GCS bucket. 
+    service_account_email = XComArg(sql_instance_create_task, key="service_account_email")
+
+    sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator(
+        entity=f"user-{service_account_email}",
+        role="WRITER",
+        bucket=file_url_split[1],  # netloc (bucket)
+        task_id="sql_gcp_add_bucket_permission_task",
+    )
+
+    # [START howto_operator_cloudsql_export_async]
+    sql_export_task = CloudSQLExportInstanceOperator(
+        body=export_body,
+        instance=INSTANCE_NAME,
+        task_id="sql_export_task",
+        deferrable=True,
+    )
+    # [END howto_operator_cloudsql_export_async]
+
+    sql_db_delete_task = CloudSQLDeleteInstanceDatabaseOperator(
+        instance=INSTANCE_NAME, database=DB_NAME, task_id="sql_db_delete_task"
+    )
+    sql_db_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+    sql_instance_delete_task = CloudSQLDeleteInstanceOperator(
+        instance=INSTANCE_NAME, task_id="sql_instance_delete_task"
+    )
+    sql_instance_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
+    )
+
+    (
+        # TEST SETUP
+        create_bucket
+        # TEST BODY
+        >> sql_instance_create_task
+        >> sql_db_create_task
+        >> sql_gcp_add_bucket_permission_task
+        >> sql_export_task
+        >> sql_db_delete_task
+        >> sql_instance_delete_task
+        # TEST TEARDOWN
+        >> delete_bucket
+    )
+
+    # Task dependencies created via `XComArgs`:
+    # sql_instance_create_task >> sql_gcp_add_bucket_permission_task
+
+    # ### Everything below this line is not part of example ###
+    # ### Just for system tests purpose ###
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
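
For reviewers, a minimal sketch of how the new deferrable mode added by this patch would be used from a DAG. The DAG id, task id, instance name and bucket URI are illustrative placeholders, not values from this change; the operator arguments (`body`, `instance`, `deferrable`, `poke_interval`) come from the diff above.

```python
# Illustrative usage sketch (not part of the patch). Assumes the Google provider
# with this change is installed; all names below are placeholders.
from __future__ import annotations

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExportInstanceOperator

EXPORT_BODY = {
    "exportContext": {
        "fileType": "sql",
        "uri": "gs://my-example-bucket/my-export-file",  # placeholder bucket/object
        "sqlExportOptions": {"schemaOnly": False},
    }
}

with models.DAG(
    "example_cloudsql_export_deferrable",  # placeholder DAG id
    schedule="@once",
    start_date=datetime(2023, 1, 1),
    catchup=False,
) as dag:
    # With deferrable=True the operator starts the export, then defers to
    # CloudSQLExportTrigger, which polls the export operation every
    # `poke_interval` seconds and fires an event when it finishes.
    sql_export_task = CloudSQLExportInstanceOperator(
        task_id="sql_export_task",
        instance="my-cloudsql-instance",  # placeholder instance name
        body=EXPORT_BODY,
        deferrable=True,
        poke_interval=10,
    )
```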
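As a quick reference for review, the sketch below summarizes the three event payload shapes the new trigger can yield, as read from the diff above, and how `CloudSQLExportInstanceOperator.execute_complete` reacts to each; the angle-bracket values are placeholders.

```python
# Reference sketch (not part of the patch) of the TriggerEvent payloads
# produced by CloudSQLExportTrigger.run().

# Operation finished without errors -> execute_complete logs success and returns.
success_event = {"operation_name": "<operation name>", "status": "success"}

# Operation finished but the API reported an error -> execute_complete raises AirflowException.
error_event = {
    "operation_name": "<operation name>",
    "status": "error",
    "message": "<error message from the operation>",
}

# An exception was raised while polling -> no "operation_name" key;
# execute_complete raises AirflowException with the message below.
failed_event = {"status": "failed", "message": "<exception text>"}
```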