diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py index abb6640516e52..0e33f22ebf59c 100644 --- a/airflow/providers/google/cloud/hooks/pubsub.py +++ b/airflow/providers/google/cloud/hooks/pubsub.py @@ -28,7 +28,7 @@ import warnings from base64 import b64decode from functools import cached_property -from typing import Sequence +from typing import Any, Sequence from uuid import uuid4 from google.api_core.exceptions import AlreadyExists, GoogleAPICallError @@ -45,11 +45,16 @@ ReceivedMessage, RetryPolicy, ) +from google.pubsub_v1.services.subscriber.async_client import SubscriberAsyncClient from googleapiclient.errors import HttpError from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import ( + PROVIDE_PROJECT_ID, + GoogleBaseAsyncHook, + GoogleBaseHook, +) from airflow.version import version @@ -496,7 +501,6 @@ def pull( self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path) try: - response = subscriber.pull( request={ "subscription": subscription_path, @@ -569,3 +573,133 @@ def acknowledge( ) self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path) + + +class PubSubAsyncHook(GoogleBaseAsyncHook): + """Class to get asynchronous hook for Google Cloud PubSub.""" + + sync_hook_class = PubSubHook + + def __init__(self, project_id: str | None = None, **kwargs: Any): + super().__init__(**kwargs) + self.project_id = project_id + self._client: SubscriberAsyncClient | None = None + + async def _get_subscriber_client(self) -> SubscriberAsyncClient: + """ + Returns async connection to the Google PubSub + :return: Google Pub/Sub asynchronous client. + """ + if not self._client: + credentials = (await self.get_sync_hook()).get_credentials() + self._client = SubscriberAsyncClient(credentials=credentials, client_info=CLIENT_INFO) + return self._client + + @GoogleBaseHook.fallback_to_default_project_id + async def acknowledge( + self, + subscription: str, + project_id: str, + ack_ids: list[str] | None = None, + messages: list[ReceivedMessage] | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Acknowledges the messages associated with the ``ack_ids`` from Pub/Sub subscription. + + :param subscription: the Pub/Sub subscription name to delete; do not + include the 'projects/{project}/topics/' prefix. + :param ack_ids: List of ReceivedMessage ackIds from a previous pull response. + Mutually exclusive with ``messages`` argument. + :param messages: List of ReceivedMessage objects to acknowledge. + Mutually exclusive with ``ack_ids`` argument. + :param project_id: Optional, the Google Cloud project name or ID in which to create the topic + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :param metadata: (Optional) Additional metadata that is provided to the method. + """ + subscriber = await self._get_subscriber_client() + if ack_ids is not None and messages is None: + pass # use ack_ids as is + elif ack_ids is None and messages is not None: + ack_ids = [message.ack_id for message in messages] # extract ack_ids from messages + else: + raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided") + + subscription_path = f"projects/{project_id}/subscriptions/{subscription}" + self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path) + + try: + await subscriber.acknowledge( + request={"subscription": subscription_path, "ack_ids": ack_ids}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except (HttpError, GoogleAPICallError) as e: + raise PubSubException( + f"Error acknowledging {len(ack_ids)} messages pulled from subscription {subscription_path}", + e, + ) + self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path) + + @GoogleBaseHook.fallback_to_default_project_id + async def pull( + self, + subscription: str, + max_messages: int, + project_id: str = PROVIDE_PROJECT_ID, + return_immediately: bool = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> list[ReceivedMessage]: + """ + Pulls up to ``max_messages`` messages from Pub/Sub subscription. + + :param subscription: the Pub/Sub subscription name to pull from; do not + include the 'projects/{project}/topics/' prefix. + :param max_messages: The maximum number of messages to return from + the Pub/Sub API. + :param project_id: Optional, the Google Cloud project ID where the subscription exists. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param return_immediately: If set, the Pub/Sub API will immediately + return if no messages are available. Otherwise, the request will + block for an undisclosed, but bounded period of time + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :param metadata: (Optional) Additional metadata that is provided to the method. + :return: A list of Pub/Sub ReceivedMessage objects each containing + an ``ackId`` property and a ``message`` property, which includes + the base64-encoded message content. See + https://cloud.google.com/pubsub/docs/reference/rpc/google.pubsub.v1#google.pubsub.v1.ReceivedMessage + """ + subscriber = await self._get_subscriber_client() + subscription_path = f"projects/{project_id}/subscriptions/{subscription}" + self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path) + + try: + response = await subscriber.pull( + request={ + "subscription": subscription_path, + "max_messages": max_messages, + "return_immediately": return_immediately, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + result = getattr(response, "received_messages", []) + self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path) + return result + except (HttpError, GoogleAPICallError) as e: + raise PubSubException(f"Error pulling messages from subscription {subscription_path}", e) diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py index dfe14b542b61d..2e03b3669d3dc 100644 --- a/airflow/providers/google/cloud/sensors/pubsub.py +++ b/airflow/providers/google/cloud/sensors/pubsub.py @@ -18,11 +18,14 @@ """This module contains a Google PubSub sensor.""" from __future__ import annotations +from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Sequence from google.cloud.pubsub_v1.types import ReceivedMessage +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.pubsub import PubSubHook +from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -79,6 +82,7 @@ class PubSubPullSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Run sensor in deferrable mode """ template_fields: Sequence[str] = ( @@ -98,6 +102,8 @@ def __init__( gcp_conn_id: str = "google_cloud_default", messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, impersonation_chain: str | Sequence[str] | None = None, + poke_interval: float = 10.0, + deferrable: bool = False, **kwargs, ) -> None: @@ -109,14 +115,10 @@ def __init__( self.ack_messages = ack_messages self.messages_callback = messages_callback self.impersonation_chain = impersonation_chain - + self.deferrable = deferrable + self.poke_interval = poke_interval self._return_value = None - def execute(self, context: Context) -> Any: - """Overridden to allow messages to be passed.""" - super().execute(context) - return self._return_value - def poke(self, context: Context) -> bool: hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, @@ -143,6 +145,41 @@ def poke(self, context: Context) -> bool: return bool(pulled_messages) + def execute(self, context: Context) -> None: + """ + Airflow runs this method on the worker and defers using the triggers + if deferrable is set to True. + """ + if not self.deferrable: + super().execute(context) + return self._return_value + else: + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=PubsubPullTrigger( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + ack_messages=self.ack_messages, + messages_callback=self.messages_callback, + poke_interval=self.poke_interval, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]: + """ + Callback for when the trigger fires; returns immediately. + Relies on trigger to throw a success event. + """ + if event["status"] == "success": + self.log.info("Sensor pulls messages: %s", event["message"]) + return event["message"] + self.log.info("Sensor failed: %s", event["message"]) + raise AirflowException(event["message"]) + def _default_message_callback( self, pulled_messages: list[ReceivedMessage], diff --git a/airflow/providers/google/cloud/triggers/pubsub.py b/airflow/providers/google/cloud/triggers/pubsub.py new file mode 100644 index 0000000000000..40c43f7cb79dd --- /dev/null +++ b/airflow/providers/google/cloud/triggers/pubsub.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud Pubsub triggers.""" +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Sequence + +from google.cloud.pubsub_v1.types import ReceivedMessage + +from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class PubsubPullTrigger(BaseTrigger): + """ + Initialize the Pubsub Pull Trigger with needed parameters. + + :param project_id: the Google Cloud project ID for the subscription (templated) + :param subscription: the Pub/Sub subscription name. Do not include the full subscription path. + :param max_messages: The maximum number of messages to retrieve per + PubSub pull request + :param ack_messages: If True, each message will be acknowledged + immediately rather than by any downstream tasks + :param gcp_conn_id: Reference to google cloud connection id + :param messages_callback: (Optional) Callback to process received messages. + It's return value will be saved to XCom. + If you are pulling large messages, you probably want to provide a custom callback. + If not provided, the default implementation will convert `ReceivedMessage` objects + into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function. + :param poke_interval: polling period in seconds to check for the status + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + def __init__( + self, + project_id: str, + subscription: str, + max_messages: int, + ack_messages: bool, + gcp_conn_id: str, + messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, + poke_interval: float = 10.0, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.project_id = project_id + self.subscription = subscription + self.max_messages = max_messages + self.ack_messages = ack_messages + self.messages_callback = messages_callback + self.poke_interval = poke_interval + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook = PubSubAsyncHook() + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes PubsubPullTrigger arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.pubsub.PubsubPullTrigger", + { + "project_id": self.project_id, + "subscription": self.subscription, + "max_messages": self.max_messages, + "ack_messages": self.ack_messages, + "messages_callback": self.messages_callback, + "poke_interval": self.poke_interval, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] + try: + pulled_messages = None + while True: + if pulled_messages: + if self.ack_messages: + await self.message_acknowledgement(pulled_messages) + yield TriggerEvent({"status": "success", "message": pulled_messages}) + else: + yield TriggerEvent({"status": "success", "message": pulled_messages}) + else: + pulled_messages = await self.hook.pull( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + return_immediately=True, + ) + self.log.info("Sleeping for %s seconds.", self.poke_interval) + await asyncio.sleep(self.poke_interval) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + async def message_acknowledgement(self, pulled_messages): + await self.hook.acknowledge( + project_id=self.project_id, + subscription=self.subscription, + messages=pulled_messages, + ) + self.log.info("Acknowledged ack_ids from subscription %s", self.subscription) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 16f939d0a02b6..64813045844d7 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -861,6 +861,9 @@ triggers: - integration-name: Google Machine Learning Engine python-modules: - airflow.providers.google.cloud.triggers.mlengine + - integration-name: Google Cloud Pub/Sub + python-modules: + - airflow.providers.google.cloud.triggers.pubsub transfers: - source-integration-name: Presto diff --git a/docs/apache-airflow-providers-google/operators/cloud/pubsub.rst b/docs/apache-airflow-providers-google/operators/cloud/pubsub.rst index 7ab737e4ea694..b3c30bc54e8d0 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/pubsub.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/pubsub.rst @@ -88,6 +88,14 @@ and pass them through XCom. :start-after: [START howto_operator_gcp_pubsub_pull_message_with_sensor] :end-before: [END howto_operator_gcp_pubsub_pull_message_with_sensor] +Also for this action you can use sensor in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/pubsub/example_pubsub_deferrable.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gcp_pubsub_pull_message_with_async_sensor] + :end-before: [END howto_operator_gcp_pubsub_pull_message_with_async_sensor] + .. exampleinclude:: /../../tests/system/providers/google/cloud/pubsub/example_pubsub.py :language: python :start-after: [START howto_operator_gcp_pubsub_pull_message_with_operator] diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py index cad0451bc68e0..fe25df6562d71 100644 --- a/tests/providers/google/cloud/hooks/test_pubsub.py +++ b/tests/providers/google/cloud/hooks/test_pubsub.py @@ -28,7 +28,7 @@ from google.cloud.pubsub_v1.types import ReceivedMessage from googleapiclient.errors import HttpError -from airflow.providers.google.cloud.hooks.pubsub import PubSubException, PubSubHook +from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook, PubSubException, PubSubHook from airflow.providers.google.common.consts import CLIENT_INFO from airflow.version import version @@ -576,3 +576,52 @@ def test_messages_validation_negative(self, messages, error_message): with pytest.raises(PubSubException) as ctx: PubSubHook._validate_messages(messages) assert str(ctx.value) == error_message + + +class TestPubSubAsyncHook: + @pytest.fixture + def hook(self): + return PubSubAsyncHook() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubAsyncHook._get_subscriber_client") + async def test_pull(self, mock_subscriber_client, hook): + client = mock_subscriber_client.return_value + + await hook.pull( + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10, return_immediately=False + ) + + mock_subscriber_client.assert_called_once() + client.pull.assert_called_once_with( + request=dict( + subscription=EXPANDED_SUBSCRIPTION, + max_messages=10, + return_immediately=False, + ), + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubAsyncHook._get_subscriber_client") + async def test_acknowledge(self, mock_subscriber_client, hook): + client = mock_subscriber_client.return_value + + await hook.acknowledge( + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + messages=_generate_messages(3), + ) + + mock_subscriber_client.assert_called_once() + client.acknowledge.assert_called_once_with( + request=dict( + subscription=EXPANDED_SUBSCRIPTION, + ack_ids=["1", "2", "3"], + ), + retry=DEFAULT, + timeout=None, + metadata=(), + ) diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py index 952758578c3e7..88fa3e296f8d5 100644 --- a/tests/providers/google/cloud/sensors/test_pubsub.py +++ b/tests/providers/google/cloud/sensors/test_pubsub.py @@ -23,8 +23,9 @@ import pytest from google.cloud.pubsub_v1.types import ReceivedMessage -from airflow.exceptions import AirflowSensorTimeout +from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor +from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger TASK_ID = "test-task-id" TEST_PROJECT = "test-project" @@ -155,3 +156,50 @@ def messages_callback( messages_callback.assert_called_once() assert response == messages_callback_return_value + + def test_pubsub_pull_sensor_async(self): + """ + Asserts that a task is deferred and a PubsubPullTrigger will be fired + when the PubSubPullSensor is executed. + """ + task = PubSubPullSensor( + task_id="test_task_id", + ack_messages=True, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + task.execute(context={}) + assert isinstance(exc.value.trigger, PubsubPullTrigger), "Trigger is not a PubsubPullTrigger" + + def test_pubsub_pull_sensor_async_execute_should_throw_exception(self): + """Tests that an AirflowException is raised in case of error event""" + + operator = PubSubPullSensor( + task_id="test_task", + ack_messages=True, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + ) + + with pytest.raises(AirflowException): + operator.execute_complete( + context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} + ) + + def test_pubsub_pull_sensor_async_execute_complete(self): + """Asserts that logging occurs as expected""" + operator = PubSubPullSensor( + task_id="test_task", + ack_messages=True, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + ) + + test_message = "test" + with mock.patch.object(operator.log, "info") as mock_log_info: + operator.execute_complete(context={}, event={"status": "success", "message": test_message}) + mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message) diff --git a/tests/providers/google/cloud/triggers/test_pubsub.py b/tests/providers/google/cloud/triggers/test_pubsub.py new file mode 100644 index 0000000000000..d2294eb61414b --- /dev/null +++ b/tests/providers/google/cloud/triggers/test_pubsub.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger + +TEST_POLL_INTERVAL = 10 +TEST_GCP_CONN_ID = "google_cloud_default" +PROJECT_ID = "test_project_id" +MAX_MESSAGES = 5 +ACK_MESSAGES = True + + +@pytest.fixture +def trigger(): + return PubsubPullTrigger( + project_id=PROJECT_ID, + subscription="subscription", + max_messages=MAX_MESSAGES, + ack_messages=ACK_MESSAGES, + messages_callback=None, + poke_interval=TEST_POLL_INTERVAL, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=None, + ) + + +class TestPubsubPullTrigger: + def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(self, trigger): + """ + Asserts that the PubsubPullTrigger correctly serializes its arguments + and classpath. + """ + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.pubsub.PubsubPullTrigger" + assert kwargs == { + "project_id": PROJECT_ID, + "subscription": "subscription", + "max_messages": MAX_MESSAGES, + "ack_messages": ACK_MESSAGES, + "messages_callback": None, + "poke_interval": TEST_POLL_INTERVAL, + "gcp_conn_id": TEST_GCP_CONN_ID, + "impersonation_chain": None, + } diff --git a/tests/system/providers/google/cloud/pubsub/example_pubsub.py b/tests/system/providers/google/cloud/pubsub/example_pubsub.py index e7b0ee33e3b04..81687f97bb90c 100644 --- a/tests/system/providers/google/cloud/pubsub/example_pubsub.py +++ b/tests/system/providers/google/cloud/pubsub/example_pubsub.py @@ -37,7 +37,7 @@ from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "your-project-id") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") DAG_ID = "pubsub" diff --git a/tests/system/providers/google/cloud/pubsub/example_pubsub_deferrable.py b/tests/system/providers/google/cloud/pubsub/example_pubsub_deferrable.py new file mode 100644 index 0000000000000..3f5b9a2b5dd2e --- /dev/null +++ b/tests/system/providers/google/cloud/pubsub/example_pubsub_deferrable.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that uses Google PubSub services. +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import models +from airflow.providers.google.cloud.operators.pubsub import ( + PubSubCreateSubscriptionOperator, + PubSubCreateTopicOperator, + PubSubDeleteSubscriptionOperator, + PubSubDeleteTopicOperator, + PubSubPublishMessageOperator, +) +from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") + +DAG_ID = "pubsub_async" + +TOPIC_ID = f"topic-{DAG_ID}-{ENV_ID}" +MESSAGE = {"data": b"Tool", "attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}} + + +with models.DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2021, 1, 1), + catchup=False, +) as dag: + create_topic = PubSubCreateTopicOperator( + task_id="create_topic", topic=TOPIC_ID, project_id=PROJECT_ID, fail_if_exists=False + ) + + subscribe_task = PubSubCreateSubscriptionOperator( + task_id="subscribe_task", project_id=PROJECT_ID, topic=TOPIC_ID + ) + + publish_task = PubSubPublishMessageOperator( + task_id="publish_task", + project_id=PROJECT_ID, + topic=TOPIC_ID, + messages=[MESSAGE, MESSAGE], + ) + subscription = subscribe_task.output + + # [START howto_operator_gcp_pubsub_pull_message_with_async_sensor] + pull_messages_async = PubSubPullSensor( + task_id="pull_messages_async", + ack_messages=True, + project_id=PROJECT_ID, + subscription=subscription, + deferrable=True, + ) + # [END howto_operator_gcp_pubsub_pull_message_with_async_sensor] + + unsubscribe_task = PubSubDeleteSubscriptionOperator( + task_id="unsubscribe_task", + project_id=PROJECT_ID, + subscription=subscription, + ) + unsubscribe_task.trigger_rule = TriggerRule.ALL_DONE + + delete_topic = PubSubDeleteTopicOperator(task_id="delete_topic", topic=TOPIC_ID, project_id=PROJECT_ID) + delete_topic.trigger_rule = TriggerRule.ALL_DONE + + ( + create_topic + >> subscribe_task + >> publish_task + >> pull_messages_async + >> unsubscribe_task + >> delete_topic + ) + + # Task dependencies created via `XComArgs`: + # subscribe_task >> pull_messages_async + # subscribe_task >> unsubscribe_task + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)