diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000..2a3ffa6059 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,2 @@ +ignore: + - "cosmos/airflow/_override.py" diff --git a/cosmos/__init__.py b/cosmos/__init__.py index be4f762632..4b73a23d72 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -9,7 +9,8 @@ from cosmos import settings -__version__ = "1.12.1a1" +__version__ = "1.13.0a1" + if not settings.enable_memory_optimised_imports: from cosmos.airflow.dag import DbtDag diff --git a/cosmos/airflow/_override.py b/cosmos/airflow/_override.py new file mode 100644 index 0000000000..c157292d6c --- /dev/null +++ b/cosmos/airflow/_override.py @@ -0,0 +1,193 @@ +import math +import time +from collections.abc import Callable +from datetime import timedelta + +import pendulum +from airflow.providers.cncf.kubernetes import __version__ as airflow_k8s_provider_version +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode +from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus, PodManager +from airflow.utils.timezone import utcnow +from kubernetes.client.models.v1_pod import V1Pod +from packaging.version import Version +from pendulum import DateTime +from urllib3.exceptions import HTTPError, TimeoutError + +from cosmos.constants import _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION + + +# This is being added to overcome the issue with the KubernetesPodOperator logs repeating: +# https://github.com/apache/airflow/issues/59366 +# It can be removed once it is fixed in the upstream provider. 
+class CosmosKubernetesPodManager(PodManager): # type: ignore[misc] + """Create, monitor, and otherwise interact with Kubernetes pods for use with the KubernetesPodOperator.""" + + def fetch_container_logs( # noqa: C901 + self, + pod: V1Pod, + container_name: str, + *, + follow: bool = False, + since_time: DateTime | None = None, + post_termination_timeout: int = 120, + container_name_log_prefix_enabled: bool = True, + log_formatter: Callable[[str, str], str] | None = None, + ) -> PodLoggingStatus: + """ + Follow the logs of container and stream to airflow logging. + + Returns when container exits. + + Between when the pod starts and logs being available, there might be a delay due to CSR not approved + and signed yet. In such situation, ApiException is thrown. This is why we are retrying on this + specific exception. + + :meta private: + """ + + def consume_logs( # noqa: C901 + *, since_time: DateTime | None = None + ) -> tuple[DateTime | None, Exception | None]: + """ + Try to follow container logs until container completes. + + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. + + Returns the last timestamp observed in logs. + """ + # Cosmos implementation difference when compared to proposal to fix the issue in the upstream provider: + # https://github.com/apache/airflow/pull/59372/ + if Version(airflow_k8s_provider_version) >= Version("10.10.0"): + from airflow.providers.cncf.kubernetes.utils.pod_manager import parse_log_line + elif ( + Version(airflow_k8s_provider_version) >= _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION + ): # Successfully tested with Airflow 3.1.0 and K8s provider 10.8.0 and 10.9.0 + parse_log_line = self.parse_log_line + else: + raise ValueError( + f"Unsupported K8s provider version: {airflow_k8s_provider_version}. " + f"Minimum required version is {_K8s_WATCHER_MIN_K8S_PROVIDER_VERSION}" + ) + # Cosmos custom implementation finishes here. 
+ + exception = None + last_captured_timestamp = None + # We timeout connections after 30 minutes because otherwise they can get + # stuck forever. The 30 is somewhat arbitrary. + # As a consequence, a TimeoutError will be raised no more than 30 minutes + # after starting read. + connection_timeout = 60 * 30 + # We set a shorter read timeout because that helps reduce *connection* timeouts + # (since the connection will be restarted periodically). And with read timeout, + # we don't need to worry about either duplicate messages or losing messages; we + # can safely resume from a few seconds later + read_timeout = 60 * 5 + try: + since_seconds = None + if since_time: + try: + since_seconds = math.ceil((pendulum.now() - since_time).total_seconds()) + except TypeError: + self.log.warning( + "Error calculating since_seconds with since_time %s. Using None instead.", + since_time, + ) + logs = self.read_pod_logs( + pod=pod, + container_name=container_name, + timestamps=True, + since_seconds=since_seconds, + follow=follow, + post_termination_timeout=post_termination_timeout, + _request_timeout=(connection_timeout, read_timeout), + ) + message_to_log = None + message_timestamp = None + progress_callback_lines = [] + try: + for raw_line in logs: + line = raw_line.decode("utf-8", errors="backslashreplace") + line_timestamp, message = parse_log_line(line) + if line_timestamp: # detect new log line + if message_to_log is None: # first line in the log + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines.append(line) + else: # previous log line is complete + for callback in self._callbacks: + callback.progress_callback( + line=message_to_log, + client=self._client, + mode=ExecutionMode.SYNC, + container_name=container_name, + timestamp=message_timestamp, + pod=pod, + ) + self._log_message( + message_to_log, + container_name, + container_name_log_prefix_enabled, + log_formatter, + ) + last_captured_timestamp = message_timestamp + message_to_log 
= message + message_timestamp = line_timestamp + progress_callback_lines = [line] + else: # continuation of the previous log line + message_to_log = f"{message_to_log}\n{message}" + progress_callback_lines.append(line) + finally: + # log the last line and update the last_captured_timestamp + if message_to_log is not None: + for callback in self._callbacks: + callback.progress_callback( + line=message_to_log, + client=self._client, + mode=ExecutionMode.SYNC, + container_name=container_name, + timestamp=message_timestamp, + pod=pod, + ) + self._log_message( + message_to_log, container_name, container_name_log_prefix_enabled, log_formatter + ) + last_captured_timestamp = message_timestamp + except TimeoutError as e: + # in case of timeout, increment return time by 2 seconds to avoid + # duplicate log entries + if val := (last_captured_timestamp or since_time): + return val.add(seconds=2), e + except HTTPError as e: + exception = e + self._http_error_timestamps = getattr(self, "_http_error_timestamps", []) + self._http_error_timestamps = [ + t for t in self._http_error_timestamps if t > utcnow() - timedelta(seconds=60) + ] + self._http_error_timestamps.append(utcnow()) + # Log only if more than 2 errors occurred in the last 60 seconds + if len(self._http_error_timestamps) > 2: + self.log.exception( + "Reading of logs interrupted for container %r; will retry.", + container_name, + ) + return last_captured_timestamp or since_time, exception + + # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + # loop as we do here. But in a long-running process we might temporarily lose connectivity. + # So the looping logic is there to let us resume following the logs. 
+ last_log_time = since_time + while True: + last_log_time, exc = consume_logs(since_time=last_log_time) + if not self.container_is_running(pod, container_name=container_name): + return PodLoggingStatus(running=False, last_log_time=last_log_time) + if not follow: + return PodLoggingStatus(running=True, last_log_time=last_log_time) + # a timeout is a normal thing and we ignore it and resume following logs + if not isinstance(exc, TimeoutError): + self.log.warning( + "Pod %s log read interrupted but container %s still running. Logs generated in the last one second might get duplicated.", + pod.metadata.name, + container_name, + ) + time.sleep(1) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 2cbba41141..156fea563f 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -202,17 +202,20 @@ def create_test_task_metadata( if node: args_to_override = node.operator_kwargs_to_override + dbt_class = "DbtTest" + watcher_to_test_execution_mode = { + ExecutionMode.WATCHER: ExecutionMode.LOCAL, + ExecutionMode.WATCHER_KUBERNETES: ExecutionMode.KUBERNETES, + } if ( - execution_mode == ExecutionMode.WATCHER - and render_config is not None + render_config is not None and render_config.test_behavior == TestBehavior.AFTER_ALL + and execution_mode in (ExecutionMode.WATCHER, ExecutionMode.WATCHER_KUBERNETES) ): - operator_class = "cosmos.operators.local.DbtTestLocalOperator" + test_execution_mode = watcher_to_test_execution_mode[execution_mode] + operator_class = calculate_operator_class(execution_mode=test_execution_mode, dbt_class=dbt_class) else: - operator_class = calculate_operator_class( - execution_mode=execution_mode, - dbt_class="DbtTest", - ) + operator_class = calculate_operator_class(execution_mode=execution_mode, dbt_class=dbt_class) return TaskMetadata( id=test_task_name, @@ -647,6 +650,7 @@ def _add_watcher_producer_task( tasks_map: dict[str, Any], task_group: TaskGroup | None, render_config: RenderConfig | None = None, + 
execution_mode: ExecutionMode = ExecutionMode.WATCHER, ) -> BaseOperator: """ Create the producer task for the watcher execution mode and add it to the tasks_map. @@ -665,10 +669,12 @@ def _add_watcher_producer_task( "resource_type:unit_test", ] + class_name = calculate_operator_class(execution_mode, "DbtProducer") + # First, we create the producer task producer_task_metadata = TaskMetadata( id=PRODUCER_WATCHER_TASK_ID, - operator_class="cosmos.operators.watcher.DbtProducerWatcherOperator", + operator_class=class_name, arguments=producer_task_args, ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) @@ -682,13 +688,17 @@ def _add_watcher_dependencies( task_args: dict[str, Any], tasks_map: dict[str, Any], nodes: dict[str, DbtNode] | None = None, -) -> str: +) -> None: """ Iterate through the watcher consumer tasks and: - set the producer task ID in all of them - make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom """ for node_id, task_or_taskgroup in tasks_map.items(): + # We do not want to set a dependency between the producer task and itself + if node_id == PRODUCER_WATCHER_TASK_ID: + continue + node_tasks = ( list(task_or_taskgroup.children.values()) if isinstance(task_or_taskgroup, TaskGroup) @@ -701,7 +711,6 @@ def _add_watcher_dependencies( # We only managed to do this in the case of DbtDag. # The way it is implemented is by setting the trigger_rule to "always" for the consumer tasks, and by having the producer task with a high priority_weight. if "DbtDag" in dag.__class__.__name__: - # Is this dbt node a root of the (subset of the) dbt project? 
# Note: this may happen in one scenarios: # - the dbt node not having any `depends_on` within the user-selected `nodes` @@ -714,12 +723,9 @@ def _add_watcher_dependencies( ] else: always_run_tasks = [task_or_taskgroup] - for task in always_run_tasks: task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined] - return producer_airflow_task.task_id - def should_create_detached_nodes(render_config: RenderConfig) -> bool: """ @@ -859,7 +865,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro if execution_mode == ExecutionMode.AIRFLOW_ASYNC: # This property is only relevant for the setup task, not the other tasks: virtualenv_dir = task_args.pop("virtualenv_dir", None) - elif execution_mode == ExecutionMode.WATCHER: + elif execution_mode in (ExecutionMode.WATCHER, ExecutionMode.WATCHER_KUBERNETES): setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} # We are intentionally creating the producer task ahead of the consumer tasks # Airflow priority weight is not being respected in multiple versions of the library, including 3.1 @@ -870,6 +876,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro tasks_map=tasks_map, task_group=task_group, render_config=render_config, + execution_mode=execution_mode, ) for node_id, node in nodes.items(): @@ -940,7 +947,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro create_airflow_task_dependencies(nodes, tasks_map) - if execution_mode == ExecutionMode.WATCHER: + if execution_mode in (ExecutionMode.WATCHER, ExecutionMode.WATCHER_KUBERNETES): setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} _add_watcher_dependencies( dag=dag, diff --git a/cosmos/constants.py b/cosmos/constants.py index 024297d071..f768155733 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -108,6 +108,7 @@ class ExecutionMode(Enum): VIRTUALENV = "virtualenv" 
AZURE_CONTAINER_INSTANCE = "azure_container_instance" GCP_CLOUD_RUN_JOB = "gcp_cloud_run_job" + WATCHER_KUBERNETES = "watcher_kubernetes" class InvocationMode(Enum): @@ -195,3 +196,5 @@ def _missing_value_(cls, value): # type: ignore TELEMETRY_TIMEOUT = 1.0 _AIRFLOW3_MAJOR_VERSION = 3 + +_K8s_WATCHER_MIN_K8S_PROVIDER_VERSION = Version("10.8.0") diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 2a018a9146..389717a41d 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -59,8 +59,8 @@ class BaseConsumerSensor(BaseSensorOperator): # type: ignore[misc] def __init__( self, *, + project_dir: str, profile_config: ProfileConfig | None = None, - project_dir: str | None = None, profiles_dir: str | None = None, producer_task_id: str = PRODUCER_WATCHER_TASK_ID, poke_interval: int = 10, diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 3bebd54f71..8de3bfb6e7 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -82,6 +82,14 @@ def _get_xcom_val() -> Any | None: return await sync_to_async(_get_xcom_val)() async def get_xcom_val(self, key: str) -> Any | None: + self.log.info( + "Trying to retrieve value using XCom key <%s> by task_id <%s>, dag_id <%s>, run_id <%s> and map_index <%s>", + key, + self.producer_task_id, + self.dag_id, + self.run_id, + self.map_index, + ) if AIRFLOW_VERSION < Version("3.0.0"): return await self.get_xcom_val_af2(key) else: @@ -141,7 +149,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) yield TriggerEvent({"status": "failed", "reason": "producer_failed"}) # type: ignore[no-untyped-call] return - # Sleep briefly before re-polling await asyncio.sleep(self.poke_interval) self.log.debug("Polling again for model '%s' status...", self.model_unique_id) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 0b86c27cdd..29696fadc0 100644 --- 
a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -216,8 +216,8 @@ class DbtConsumerWatcherSensor(BaseConsumerSensor, DbtRunLocalOperator): # type def __init__( self, *, + project_dir: str, profile_config: ProfileConfig | None = None, - project_dir: str | None = None, profiles_dir: str | None = None, producer_task_id: str = PRODUCER_WATCHER_TASK_ID, poke_interval: int = 10, diff --git a/cosmos/operators/watcher_kubernetes.py b/cosmos/operators/watcher_kubernetes.py new file mode 100644 index 0000000000..41c4df6248 --- /dev/null +++ b/cosmos/operators/watcher_kubernetes.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import cached_property +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from pendulum import DateTime + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + from airflow.utils.context import Context # type: ignore[attr-defined] + +import kubernetes.client as k8s +from airflow.exceptions import AirflowException +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback, client_type + +try: + from airflow.providers.standard.operators.empty import EmptyOperator +except ImportError: # pragma: no cover + from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] + +from cosmos.airflow._override import CosmosKubernetesPodManager +from cosmos.log import get_logger +from cosmos.operators._watcher.base import BaseConsumerSensor, store_dbt_resource_status_from_log +from cosmos.operators.base import ( + DbtRunMixin, + DbtSeedMixin, + DbtSnapshotMixin, +) +from cosmos.operators.kubernetes import ( + DbtBuildKubernetesOperator, + DbtRunKubernetesOperator, + DbtSourceKubernetesOperator, +) + +logger = get_logger(__name__) + + +# This global variable is currently used to make the task context available to the K8s callback. 
+# While the callback is set during the operator initialization, the context is only created during the operator's execution. +producer_task_context = None + + +class WatcherKubernetesCallback(KubernetesPodOperatorCallback): # type: ignore[misc] + + @staticmethod + def progress_callback( + *, + line: str, + client: client_type, + mode: str, + container_name: str, + timestamp: DateTime | None, + pod: k8s.V1Pod, + **kwargs: Any, + ) -> None: + """ + Invoke this callback to process pod container logs. + + :param line: the read line of log. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + :param container_name: the name of the container from which the log line was read. + :param timestamp: the timestamp of the log line. + :param pod: the pod from which the log line was read. + """ + if "context" not in kwargs: + # This global variable is used to make the task context available to the K8s callback. + # While the callback is set during the operator initialization, the context is only created during the operator's execution. + kwargs["context"] = producer_task_context + store_dbt_resource_status_from_log(line, kwargs) + + +class DbtProducerWatcherKubernetesOperator(DbtBuildKubernetesOperator): + + template_fields: tuple[str, ...] 
= tuple(DbtBuildKubernetesOperator.template_fields) + ("deferrable",) + _process_log_line_callable: Callable[[str, dict[str, Any]], None] | None = store_dbt_resource_status_from_log + + def __init__(self, *args: Any, **kwargs: Any) -> None: + task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator") + + # Disable retries on producer task + default_args = dict(kwargs.get("default_args", {}) or {}) + default_args["retries"] = 0 + kwargs["default_args"] = default_args + kwargs["retries"] = 0 + + super().__init__(task_id=task_id, *args, callbacks=WatcherKubernetesCallback, **kwargs) + self.dbt_cmd_flags += ["--log-format", "json"] + + @cached_property + def pod_manager(self) -> CosmosKubernetesPodManager: + return CosmosKubernetesPodManager(kube_client=self.client, callbacks=self.callbacks) + + def execute(self, context: Context, **kwargs: Any) -> Any: + task_instance = context.get("ti") + if task_instance is None: + raise AirflowException( + "DbtProducerWatcherKubernetesOperator expects a task instance in the execution context" + ) + + try_number = getattr(task_instance, "try_number", 1) + + if try_number > 1: + retry_message = ( + "DbtProducerWatcherKubernetesOperator does not support Airflow retries. " + f"Detected attempt #{try_number}; failing fast to avoid running a second dbt build." + ) + self.log.error(retry_message) + raise AirflowException(retry_message) + + # This global variable is used to make the task context available to the K8s callback. + # While the callback is set during the operator initialization, the context is only created during the operator's execution. + global producer_task_context + producer_task_context = context + return super().execute(context, **kwargs) + + +class DbtConsumerWatcherKubernetesSensor(BaseConsumerSensor, DbtRunKubernetesOperator): + template_fields: tuple[str, ...] 
= BaseConsumerSensor.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator] + + def use_event(self) -> bool: + return False + + +# This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task. +# That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. +class DbtBuildWatcherKubernetesOperator: + def __init__(self, *args: Any, **kwargs: Any): + raise NotImplementedError( + "`ExecutionMode.WATCHER` does not expose a DbtBuild operator, since the build command is executed by the producer task." + ) + + +class DbtSeedWatcherKubernetesOperator(DbtSeedMixin, DbtConsumerWatcherKubernetesSensor): # type: ignore[misc] + """ + Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherKubernetesOperator). + """ + + template_fields: tuple[str, ...] = DbtConsumerWatcherKubernetesSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtSnapshotWatcherKubernetesOperator(DbtSnapshotMixin, DbtConsumerWatcherKubernetesSensor): # type: ignore[misc] + """ + Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherKubernetesOperator). + """ + + template_fields: tuple[str, ...] = DbtConsumerWatcherKubernetesSensor.template_fields + + +class DbtSourceWatcherKubernetesOperator(DbtSourceKubernetesOperator): + """ + Executes a dbt source freshness command, synchronously, as ExecutionMode.KUBERNETES. + """ + + template_fields: tuple[str, ...] = tuple(DbtSourceKubernetesOperator.template_fields) # type: ignore[arg-type] + + +class DbtRunWatcherKubernetesOperator(DbtConsumerWatcherKubernetesSensor): + """ + Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherKubernetesOperator). + """ + + template_fields: tuple[str, ...]
= DbtConsumerWatcherKubernetesSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtTestWatcherKubernetesOperator(EmptyOperator): + """ + As a starting point, this operator does nothing. + We'll be implementing this operator as part of: https://github.com/astronomer/astronomer-cosmos/issues/1974 + """ + + def __init__(self, *args: Any, **kwargs: Any): + desired_keys = ("dag", "task_group", "task_id") + new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys} + super().__init__(**new_kwargs) # type: ignore[no-untyped-call] diff --git a/dev/dags/jaffle_shop_watcher_kubernetes.py b/dev/dags/jaffle_shop_watcher_kubernetes.py new file mode 100644 index 0000000000..acfb497c1d --- /dev/null +++ b/dev/dags/jaffle_shop_watcher_kubernetes.py @@ -0,0 +1,106 @@ +""" +## Jaffle Shop Airflow DAG using ExecutionMode.WATCHER_KUBERNETES + +[Jaffle Shop](https://github.com/dbt-labs/jaffle_shop) is a fictional eCommerce store. This dbt project originates from +dbt Labs as an example project with dummy data to demonstrate a working dbt core project. + +This DAG uses Cosmos in a way that there is a clear split between Airflow and dbt: +- The Airflow DAG is built using dbt manifest.json file +- The dbt commands are run inside Kubernetes pods + +This allows users to not have to install dbt in their Airflow deployment. 
+ +This approach is a hybrid between the Cosmos ExecutionMode.KUBERNETES: +https://astronomer.github.io/astronomer-cosmos/getting_started/kubernetes.html#kubernetes + +And the Cosmos ExecutionMode.WATCHER: +https://astronomer.github.io/astronomer-cosmos/getting_started/watcher-execution-mode.html +""" + +import os +from pathlib import Path + +from airflow.providers.cncf.kubernetes.secret import Secret +from pendulum import datetime + +from cosmos import DbtDag +from cosmos.config import ( + ExecutionConfig, + ProfileConfig, + ProjectConfig, + RenderConfig, +) +from cosmos.constants import ExecutionMode, LoadMode + +DEFAULT_DBT_ROOT_PATH = Path(__file__).resolve().parent / "dbt" + +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +AIRFLOW_DBT_PROJECT_DIR = DBT_ROOT_PATH / "jaffle_shop" + +K8S_PROJECT_DIR = "dags/dbt/jaffle_shop" +KBS_DBT_PROFILES_YAML_FILEPATH = Path(K8S_PROJECT_DIR) / "profiles.yml" + +DBT_IMAGE = "dbt-jaffle-shop:1.0.0" + +project_seeds = [{"project": "jaffle_shop", "seeds": ["raw_customers", "raw_payments", "raw_orders"]}] + +postgres_password_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_PASSWORD", + secret="postgres-secrets", + key="password", +) + +postgres_host_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_HOST", + secret="postgres-secrets", + key="host", +) + +operator_args = { + "deferrable": False, + "image": DBT_IMAGE, + "get_logs": True, + "is_delete_operator_pod": False, + "log_events_on_failure": True, + "secrets": [postgres_password_secret, postgres_host_secret], + "env_vars": { + "POSTGRES_DB": "postgres", + "POSTGRES_SCHEMA": "public", + "POSTGRES_USER": "postgres", + }, + "retry": 0, +} + +profile_config = ProfileConfig( + profile_name="postgres_profile", target_name="dev", profiles_yml_filepath=KBS_DBT_PROFILES_YAML_FILEPATH +) + +project_config = ProjectConfig( + project_name="jaffle_shop", + manifest_path=AIRFLOW_DBT_PROJECT_DIR / "target/manifest.json", +) + 
+render_config = RenderConfig(load_method=LoadMode.DBT_MANIFEST) + + +# Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +dag = DbtDag( + dag_id="jaffle_shop_watcher_kubernetes", + start_date=datetime(2022, 11, 27), + doc_md=__doc__, + catchup=False, + # Cosmos-specific parameters: + project_config=project_config, + profile_config=profile_config, + render_config=render_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER_KUBERNETES, + dbt_project_path=K8S_PROJECT_DIR, + ), + operator_args=operator_args, +) diff --git a/pyproject.toml b/pyproject.toml index 9ad6d24578..0f832b3fb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,6 +229,11 @@ serve = "sphinx-autobuild docs docs/_build" line-length = 120 target-version = ["py310", "py311", "py312", "py313"] +[tool.coverage.run] +omit = [ + "cosmos/airflow/_override.py" +] + [tool.isort] profile = "black" known_third_party = ["airflow", "jinja2"] diff --git a/scripts/test/integration-kubernetes.sh b/scripts/test/integration-kubernetes.sh index b8817cd870..c0980843ca 100644 --- a/scripts/test/integration-kubernetes.sh +++ b/scripts/test/integration-kubernetes.sh @@ -3,6 +3,13 @@ set -x set -e +# So we can validate ExecutionMode.WATCHER_KUBERNETES +actual_version=$(airflow version | cut -d. 
-f1) +if [ "$actual_version" = "3" ] ; then + pip install "apache-airflow-providers-cncf-kubernetes==10.8.0" "protobuf==6.33.2" + +fi + # Reset the Airflow database to its initial state airflow db reset -y @@ -12,4 +19,5 @@ pytest -vv \ --cov-report=term-missing \ --cov-report=xml \ -m 'integration and not dbtfusion' \ - tests/test_example_k8s_dags.py + tests/test_example_k8s_dags.py \ + tests/operators/test_watcher_kubernetes_integration.py diff --git a/scripts/test/integration.sh b/scripts/test/integration.sh index 30e48293e4..f988198076 100644 --- a/scripts/test/integration.sh +++ b/scripts/test/integration.sh @@ -22,4 +22,5 @@ pytest -vv \ --ignore=tests/perf \ --ignore=tests/test_async_example_dag.py \ --ignore=tests/test_example_k8s_dags.py \ - -k 'not (simple_dag_async or example_cosmos_python_models or example_virtualenv or jaffle_shop_kubernetes)' + --ignore=tests/operators/test_watcher_kubernetes_integration.py \ + -k 'not (simple_dag_async or example_cosmos_python_models or example_virtualenv or jaffle_shop_kubernetes or jaffle_shop_watcher_kubernetes)' diff --git a/scripts/test/unit-cov.sh b/scripts/test/unit-cov.sh index 50cd268c40..570735a7a4 100644 --- a/scripts/test/unit-cov.sh +++ b/scripts/test/unit-cov.sh @@ -8,4 +8,5 @@ pytest \ --ignore=tests/test_example_dags.py \ --ignore=tests/test_async_example_dag.py \ --ignore=tests/test_example_dags_no_connections.py \ - --ignore=tests/test_example_k8s_dags.py + --ignore=tests/test_example_k8s_dags.py \ + --ignore=tests/operators/test_watcher_kubernetes_integration.py diff --git a/scripts/test/unit.sh b/scripts/test/unit.sh index 373c109a73..18d65a2a51 100644 --- a/scripts/test/unit.sh +++ b/scripts/test/unit.sh @@ -5,4 +5,5 @@ pytest \ --ignore=tests/test_example_dags.py \ --ignore=tests/test_async_example_dag.py \ --ignore=tests/test_example_dags_no_connections.py \ - --ignore=tests/test_example_k8s_dags.py + --ignore=tests/test_example_k8s_dags.py \ + 
--ignore=tests/operators/test_watcher_kubernetes_integration.py diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 9f30e27c54..2488be63ba 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1691,3 +1691,24 @@ def test_skip_test_task_when_only_detached_tests_exist(): ] assert list(tasks_map.keys()) == expected_task_ids + + +def test_create_test_task_metadata_watcher_kubernetes_after_all(): + """ + Test that create_test_task_metadata creates a DbtTestKubernetesOperator + when test_behavior is AFTER_ALL and execution_mode is WATCHER_KUBERNETES. + """ + render_config = RenderConfig( + test_behavior=TestBehavior.AFTER_ALL, + ) + + metadata = create_test_task_metadata( + test_task_name="my_project_test", + execution_mode=ExecutionMode.WATCHER_KUBERNETES, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args={"project_dir": SAMPLE_PROJ_PATH}, + render_config=render_config, + ) + + assert metadata.id == "my_project_test" + assert metadata.operator_class == "cosmos.operators.kubernetes.DbtTestKubernetesOperator" diff --git a/tests/operators/_watcher/__init__.py b/tests/operators/_watcher/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/tests/operators/_watcher/__init__.py +++ b/tests/operators/_watcher/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index b10f3d7dbe..c1b3a9ad4c 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -3,7 +3,8 @@ import pytest from packaging.version import Version -from cosmos.operators._watcher.triggerer import AIRFLOW_VERSION, WatcherTrigger +from cosmos.constants import AIRFLOW_VERSION +from cosmos.operators._watcher.triggerer import WatcherTrigger _real_import = __import__ @@ -141,7 +142,6 @@ async def fake_get_xcom_val(key): @pytest.mark.asyncio async def test_get_producer_task_status_airflow2(self): 
fetcher = MagicMock(return_value="failed") - with patch("cosmos.operators._watcher.triggerer.AIRFLOW_VERSION", Version("2.9.0")): with patch( "cosmos.operators._watcher.triggerer.build_producer_state_fetcher", return_value=fetcher @@ -183,7 +183,6 @@ async def test_get_producer_task_status_airflow3(self): @pytest.mark.asyncio async def test_get_producer_task_status_airflow3_missing_state(self): fetcher = MagicMock(return_value=None) - with patch("cosmos.operators._watcher.triggerer.AIRFLOW_VERSION", Version("3.0.0")): with patch("cosmos.operators._watcher.triggerer.build_producer_state_fetcher", return_value=fetcher): state = await self.trigger._get_producer_task_status() diff --git a/tests/operators/_watcher/test_watcher_base.py b/tests/operators/_watcher/test_watcher_base.py new file mode 100644 index 0000000000..775eeb66f1 --- /dev/null +++ b/tests/operators/_watcher/test_watcher_base.py @@ -0,0 +1,24 @@ +import pytest + +from cosmos.operators._watcher.base import BaseConsumerSensor +from cosmos.operators.local import DbtRunLocalOperator + + +class TestBaseConsumerSensor: + + def test__methods_to_be_implemented(self): + class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): + something_to_be_implemented = True + + sensor = SubclassBaseConsumerSensor( + task_id="test_sensor", + model_unique_id="model.jaffle_shop.stg_orders", + producer_task_id="dbt_run_local", + profile_config=None, + project_dir="/tmp/sample_project", + ) + with pytest.raises(NotImplementedError): + sensor.use_event() + + with pytest.raises(NotImplementedError): + assert sensor._get_status_from_events(None, None) is None diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 7f313fc324..f631fa0022 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -18,7 +18,7 @@ from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, TestBehavior from cosmos.config import InvocationMode from 
cosmos.constants import PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode -from cosmos.operators._watcher import WatcherTrigger +from cosmos.operators._watcher.triggerer import WatcherTrigger from cosmos.operators.watcher import ( DbtBuildWatcherOperator, DbtConsumerWatcherSensor, diff --git a/tests/operators/test_watcher_kubernetes_integration.py b/tests/operators/test_watcher_kubernetes_integration.py new file mode 100644 index 0000000000..dfb4e99897 --- /dev/null +++ b/tests/operators/test_watcher_kubernetes_integration.py @@ -0,0 +1,142 @@ +import os +from datetime import datetime +from pathlib import Path + +import pytest +from airflow.providers.cncf.kubernetes.secret import Secret + +from cosmos import DbtDag +from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import AIRFLOW_VERSION, ExecutionMode, LoadMode, TestBehavior, Version +from cosmos.operators.watcher_kubernetes import ( + DbtProducerWatcherKubernetesOperator, + DbtRunWatcherKubernetesOperator, + DbtSeedWatcherKubernetesOperator, +) +from tests.utils import run_dag + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt" + +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +AIRFLOW_DBT_PROJECT_DIR = DBT_ROOT_PATH / "jaffle_shop" + +K8S_PROJECT_DIR = "dags/dbt/jaffle_shop" +KBS_DBT_PROFILES_YAML_FILEPATH = Path(K8S_PROJECT_DIR) / "profiles.yml" + +DBT_IMAGE = "dbt-jaffle-shop:1.0.0" + +project_seeds = [{"project": "jaffle_shop", "seeds": ["raw_customers", "raw_payments", "raw_orders"]}] + +postgres_password_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_PASSWORD", + secret="postgres-secrets", + key="password", +) + +postgres_host_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_HOST", + secret="postgres-secrets", + key="host", +) + +operator_args = { + "deferrable": False, + "image": DBT_IMAGE, + "get_logs": True, + "is_delete_operator_pod": False, + 
"log_events_on_failure": True, + "secrets": [postgres_password_secret, postgres_host_secret], + "env_vars": { + "POSTGRES_DB": "postgres", + "POSTGRES_SCHEMA": "public", + "POSTGRES_USER": "postgres", + }, + "retry": 0, +} + +profile_config = ProfileConfig( + profile_name="postgres_profile", target_name="dev", profiles_yml_filepath=KBS_DBT_PROFILES_YAML_FILEPATH +) + +project_config = ProjectConfig( + project_name="jaffle_shop", + manifest_path=AIRFLOW_DBT_PROJECT_DIR / "target/manifest.json", +) + +render_config = RenderConfig(load_method=LoadMode.DBT_MANIFEST, test_behavior=TestBehavior.NONE) + + +# Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + + +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("3.1"), + reason="We are only testing watcher Kubernetes with Airflow 3.1 and more recent versions of the K8s provider", +) +@pytest.mark.integration +def test_dbt_dag_with_watcher_kubernetes(): + """ + Create a Cosmos DbtDag with `ExecutionMode.WATCHER_KUBERNETES`. + Confirm the right number of tasks is created and that tasks are in the expected topological order. + Confirm that the producer watcher task is created and that it is the parent of the root dbt nodes. 
+ """ + + dag_dbt_watcher_kubernetes = DbtDag( + dag_id="watcher_kubernetes_dag", + start_date=datetime(2022, 11, 27), + doc_md=__doc__, + catchup=False, + # Cosmos-specific parameters: + project_config=project_config, + profile_config=profile_config, + render_config=render_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER_KUBERNETES, + dbt_project_path=K8S_PROJECT_DIR, + ), + operator_args=operator_args, + ) + + run_dag(dag_dbt_watcher_kubernetes) + + assert len(dag_dbt_watcher_kubernetes.task_dict) == 9 + tasks_names = [task.task_id for task in dag_dbt_watcher_kubernetes.topological_sort()] + + expected_task_names = [ + "dbt_producer_watcher", + "raw_customers_seed", + "raw_orders_seed", + "raw_payments_seed", + "stg_customers_run", + "stg_orders_run", + "stg_payments_run", + "customers_run", + "orders_run", + ] + assert tasks_names == expected_task_names + + assert isinstance( + dag_dbt_watcher_kubernetes.task_dict["dbt_producer_watcher"], + DbtProducerWatcherKubernetesOperator, + ) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["raw_customers_seed"], DbtSeedWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["raw_orders_seed"], DbtSeedWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["raw_payments_seed"], DbtSeedWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["stg_customers_run"], DbtRunWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["stg_orders_run"], DbtRunWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["stg_payments_run"], DbtRunWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["customers_run"], DbtRunWatcherKubernetesOperator) + assert isinstance(dag_dbt_watcher_kubernetes.task_dict["orders_run"], DbtRunWatcherKubernetesOperator) + + expected_downstream_task_ids = { + "raw_payments_seed", + "raw_orders_seed", + 
"raw_customers_seed", + } + assert ( + dag_dbt_watcher_kubernetes.task_dict["dbt_producer_watcher"].downstream_task_ids == expected_downstream_task_ids + ) diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py new file mode 100644 index 0000000000..4ef9207cd6 --- /dev/null +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -0,0 +1,224 @@ +import logging +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from airflow.exceptions import AirflowException +from airflow.providers.cncf.kubernetes import __version__ as airflow_k8s_provider_version +from airflow.providers.cncf.kubernetes.secret import Secret +from packaging.version import Version + +from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import LoadMode, TestBehavior, _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION + +if Version(airflow_k8s_provider_version) < _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION: + pytest.skip( + f"Watcher Kubernetes depends on apache-airflow-providers-cncf-kubernetes >= {_K8s_WATCHER_MIN_K8S_PROVIDER_VERSION}. 
Currenl version: {airflow_k8s_provider_version} ", + allow_module_level=True, + ) +else: + from cosmos.operators.watcher_kubernetes import ( + DbtBuildWatcherKubernetesOperator, + DbtConsumerWatcherKubernetesSensor, + DbtProducerWatcherKubernetesOperator, + ) + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt" + +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +AIRFLOW_DBT_PROJECT_DIR = DBT_ROOT_PATH / "jaffle_shop" + +K8S_PROJECT_DIR = "dags/dbt/jaffle_shop" +KBS_DBT_PROFILES_YAML_FILEPATH = Path(K8S_PROJECT_DIR) / "profiles.yml" + +DBT_IMAGE = "dbt-jaffle-shop:1.0.0" + +project_seeds = [{"project": "jaffle_shop", "seeds": ["raw_customers", "raw_payments", "raw_orders"]}] + +postgres_password_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_PASSWORD", + secret="postgres-secrets", + key="password", +) + +postgres_host_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_HOST", + secret="postgres-secrets", + key="host", +) + +operator_args = { + "deferrable": False, + "image": DBT_IMAGE, + "get_logs": True, + "is_delete_operator_pod": False, + "log_events_on_failure": True, + "secrets": [postgres_password_secret, postgres_host_secret], + "env_vars": { + "POSTGRES_DB": "postgres", + "POSTGRES_SCHEMA": "public", + "POSTGRES_USER": "postgres", + }, + "retry": 0, +} + +profile_config = ProfileConfig( + profile_name="postgres_profile", target_name="dev", profiles_yml_filepath=KBS_DBT_PROFILES_YAML_FILEPATH +) + +project_config = ProjectConfig( + project_name="jaffle_shop", + manifest_path=AIRFLOW_DBT_PROJECT_DIR / "target/manifest.json", +) + +render_config = RenderConfig(load_method=LoadMode.DBT_MANIFEST, test_behavior=TestBehavior.NONE) + + +def test_retries_set_to_zero_on_init(): + """ + Test that the operator sets retries to 0 during initialization. 
+ """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + assert op.retries == 0 + + +def test_retries_overridden_even_if_user_sets_them(): + """ + Test that even if a user explicitly sets retries, they are overridden to 0. + """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + retries=5, + ) + assert op.retries == 0 + + +@patch("cosmos.operators.kubernetes.DbtBuildKubernetesOperator.execute") +def test_blocks_retry_attempt(mock_execute, caplog): + """ + Test that the operator raises an AirflowException when a retry is attempted (try_number > 1). + """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + + ti = MagicMock() + ti.try_number = 2 + context = {"ti": ti} + + with caplog.at_level(logging.ERROR): + with pytest.raises(AirflowException) as excinfo: + op.execute(context=context) + + mock_execute.assert_not_called() + assert "does not support Airflow retries" in str(excinfo.value) + assert any("does not support Airflow retries" in message for message in caplog.messages) + + +def test_raises_exception_when_task_instance_missing(): + """ + Test that the operator raises an AirflowException when task instance is missing from context. + """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + + context = {"ti": None} + + with pytest.raises(AirflowException) as excinfo: + op.execute(context=context) + + assert "expects a task instance" in str(excinfo.value) + + +def test_dbt_build_watcher_kubernetes_operator_raises_not_implemented_error(): + expected_message = ( + "`ExecutionMode.WATCHER` does not expose a DbtBuild operator, " + "since the build command is executed by the producer task." 
+ ) + + with pytest.raises(NotImplementedError, match=expected_message): + DbtBuildWatcherKubernetesOperator() + + +def make_sensor(**kwargs): + extra_context = {"dbt_node_config": {"unique_id": "model.jaffle_shop.stg_orders"}} + kwargs["extra_context"] = extra_context + sensor = DbtConsumerWatcherKubernetesSensor( + task_id="model.my_model", + project_dir="/tmp/project", + profile_config=None, + deferrable=False, + image="dbt-image:latest", + **kwargs, + ) + sensor._get_producer_task_status = MagicMock(return_value=None) + return sensor + + +def make_context(ti_mock, *, run_id: str = "test-run", map_index: int = 0): + return { + "ti": ti_mock, + "run_id": run_id, + "task_instance": MagicMock(map_index=map_index), + } + + +def test_first_execution_behaves_as_base_consumer_sensor(): + """ + On the first execution (try_number == 1), the sensor should poke for status + from XCom, behaving as BaseConsumerSensor. + """ + sensor = make_sensor() + + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = "success" + context = make_context(ti) + + result = sensor.poke(context) + + assert result is True + ti.xcom_pull.assert_called() + + +@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_and_run_cmd") +def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): + """ + On retry (try_number > 1), the sensor should fall back to executing + as DbtRunKubernetesOperator by calling build_and_run_cmd. + """ + sensor = make_sensor() + + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--threads", "2"] + context = make_context(ti) + + result = sensor.poke(context) + + assert result is True + mock_build_and_run_cmd.assert_called_once() + + +def test_use_event_returns_false(): + """ + DbtConsumerWatcherKubernetesSensor should return False for use_event(), + meaning it uses XCom-based status retrieval instead of events. 
+ """ + sensor = make_sensor() + assert sensor.use_event() is False diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index d75b896516..314064b6f6 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -19,13 +19,13 @@ EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "dev/dags" AIRFLOW_IGNORE_FILE = EXAMPLE_DAGS_DIR / ".airflowignore" DBT_VERSION = Version(get_dbt_version().to_version_string()[1:]) -KUBERNETES_DAGS = ["jaffle_shop_kubernetes"] +KUBERNETES_DAGS = ["jaffle_shop_kubernetes", "jaffle_shop_watcher_kubernetes"] MIN_VER_DAG_FILE: dict[str, list[str]] = { "2.8": ["cosmos_manifest_example.py", "simple_dag_async.py", "cosmos_callback_dag.py"], } -IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py"] +IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py", "jaffle_shop_watcher_kubernetes.py"] # Sort descending based on Versions and convert string to an actual version MIN_VER_DAG_FILE_VER: dict[Version, list[str]] = { diff --git a/tests/test_example_dags_no_connections.py b/tests/test_example_dags_no_connections.py index 94cae87877..cb7eb68486 100644 --- a/tests/test_example_dags_no_connections.py +++ b/tests/test_example_dags_no_connections.py @@ -17,7 +17,7 @@ "2.8": ["cosmos_manifest_example.py"], } -IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py"] +IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py", "jaffle_shop_watcher_kubernetes.py"] # Sort descending based on Versions and convert string to an actual version MIN_VER_DAG_FILE_VER: dict[Version, list[str]] = { diff --git a/tests/test_example_k8s_dags.py b/tests/test_example_k8s_dags.py index d789505b93..0c0cb6a6b5 100644 --- a/tests/test_example_k8s_dags.py +++ b/tests/test_example_k8s_dags.py @@ -6,12 +6,14 @@ from airflow.utils.db import create_default_connections from airflow.utils.session import provide_session +from cosmos.constants import _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION + from . 
import utils as test_utils EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "dev/dags" AIRFLOW_IGNORE_FILE = EXAMPLE_DAGS_DIR / ".airflowignore" -KUBERNETES_DAG_FILES = ["jaffle_shop_kubernetes.py"] +KUBERNETES_DAG_FILES = ["jaffle_shop_kubernetes.py", "jaffle_shop_watcher_kubernetes.py"] @provide_session @@ -39,7 +41,24 @@ def get_all_dag_files(): def test_example_dag_kubernetes(session): get_all_dag_files() db = DagBag(EXAMPLE_DAGS_DIR, include_examples=False) - # for dag_id in KUBERNETES_DAG_FILES: + assert not db.import_errors dag = db.get_dag("jaffle_shop_kubernetes") + test_utils.run_dag(dag) + + +from airflow.providers.cncf.kubernetes import __version__ as airflow_k8s_provider_version +from packaging.version import Version + + +@pytest.mark.skipif( + Version(airflow_k8s_provider_version) < _K8s_WATCHER_MIN_K8S_PROVIDER_VERSION, + reason="This feature is only available for K8s provider 10.8.0 and above", +) +@pytest.mark.integration +def test_example_dag_watcher_kubernetes(session): + get_all_dag_files() + db = DagBag(EXAMPLE_DAGS_DIR, include_examples=False) + dag = db.get_dag("jaffle_shop_watcher_kubernetes") assert not db.import_errors + assert dag is not None test_utils.run_dag(dag)