diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml index 812ac5368a..efd1fd6d37 100644 --- a/.github/workflows/actionlint.yml +++ b/.github/workflows/actionlint.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Validate workflow files - uses: reviewdog/action-actionlint@0d952c597ef8459f634d7145b0b044a9699e5e43 # v1.71.0 + uses: reviewdog/action-actionlint@6fb7acc99f4a1008869fa8a0f09cfca740837d9d # v1.72.0 with: github_token: ${{ secrets.GITHUB_TOKEN }} reporter: github-check diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b75da2f631..59bbed92d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -85,6 +85,7 @@ jobs: .github/ISSUE_TEMPLATE/*) ;; .github/dependabot.yml) ;; .github/workflows/docs.yml) ;; + .github/workflows/docs-build.yml) ;; .github/workflows/stale.yml) ;; .github/workflows/actionlint.yml) ;; .github/workflows/codeql.yml) ;; @@ -243,6 +244,7 @@ jobs: env: AIRFLOW__COSMOS__ENABLE_CACHE_DBT_LS: 0 AIRFLOW__COSMOS__ENABLE_CACHE_DBT_YAML_SELECTORS: 0 + AIRFLOW__COSMOS__ENABLE_LAX_SELECTOR_PARSING: 0 AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/ AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT: 120 AIRFLOW__CORE__DAG_FILE_PROCESSOR_TIMEOUT: 120 @@ -575,6 +577,7 @@ jobs: env: AIRFLOW__COSMOS__ENABLE_CACHE_DBT_LS: 0 AIRFLOW__COSMOS__ENABLE_CACHE_DBT_YAML_SELECTORS: 0 + AIRFLOW__COSMOS__ENABLE_LAX_SELECTOR_PARSING: 0 AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/ AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT: 90.0 PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH diff --git a/.gitignore b/.gitignore index e2c14d9eba..8a8d245d87 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ actionlint # dbt_packages is a directory that gets created when you run dbt deps dev/dags/dbt/jaffle_shop/dbt_packages/ +# dbt package lockfile — removed so dbt resolves the version range in packages.yml at runtime +dev/dags/dbt/jaffle_shop/package-lock.yml + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c745c67d86..7fe603a0f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -100,29 +100,8 @@ repos: - id: blacken-docs alias: black additional_dependencies: [black>=26.3.1] - - repo: local - hooks: - - id: check-docs-build - name: Ensure the docs build without errors - entry: sphinx-build -b html docs docs/_build --fail-on-warning --fresh-env - language: python - additional_dependencies: - [ - "aenum", - "deprecation", - "msgpack", - "apache-airflow", - "pydata-sphinx-theme>=0.16.0", - "sphinx", - "sphinx-autoapi", - "sphinx-autobuild", - "sphinx-reredirects", - "sphinxcontrib.mermaid", - ] - pass_filenames: false - files: ^docs/ - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.19.1" + rev: "v1.20.0" hooks: - id: mypy name: mypy-python @@ -142,7 +121,7 @@ exclude: "dev/dags/dbt/simple/dbt_packages/.*" ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks - autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate + autoupdate_schedule: quarterly skip: - mypy # build of https://github.com/pre-commit/mirrors-mypy:types-PyYAML,types-attrs,attrs,types-requests, #types-python-dateutil,apache-airflow@v1.5.0 for python@python3 exceeds tier max size 250MiB: 262.6MiB diff --git a/CHANGELOG.rst b/CHANGELOG.rst index baa266a52a..c11a8ec5f3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,42 @@ Changelog ========= +1.14.1 (2026-04-23) +------------------- + +Bug Fixes + +* Fix ``ExecutionMode.WATCHER`` producer retry behaviour by @tatiana in #2559 +* Prevent watcher producer skip propagating to downstream tasks via gateway task by @johnhoran and @tatiana in #2597 +* Keep watcher sensor polling when producer is still running by @pankajkoti in #2592 +* Fix circular import error in Cosmos plugin discovery under Astro Runtime by @tatiana in #2538 +* Fix ``CosmosRichLogger`` crash on ``None`` log message by @tatiana in #2540 +* Enable inlets and outlets using dbt Fusion on Airflow 3 by @ichirotakami in #2561 +* Fix incorrectly skipped source downstream tasks in ``ExecutionMode.WATCHER`` by @pankajastro in #2563 +* Fix duplicate logs in ``dbt build`` when source freshness is enabled by @pankajastro in #2564 +* Warn and normalize when ``source_rendering_behavior=None`` is passed by @pankajastro in #2570 +* Gracefully handle ``Variable.set()`` failures on Astro Remote Execution by @hkc-8010 in #2573 +* Skip malformed YAML selectors instead of failing entirely by @YourRoyalLinus in #2577 + +Docs + +* Update watcher test behavior docs for Cosmos 1.14.0 by @tatiana in #2549 +* Add redirect for moved partial-parsing docs page by @tatiana in #2550 +* Document ``ExecutionMode.WATCHER`` and ``depends_on_past`` limitation by @tatiana in #2602 +* Restore memory-optimised imports docs for Cosmos < 1.14.0 by @pankajkoti in #2604 + +Others + +* Speed up Airflow 3.1+ integration tests by caching InProcessExecutionAPI by @pankajkoti in #2547 +* Improve stability of cache hash unit tests by @tatiana in #2539 +* Fix mypy 1.20.0 type check failures by @pankajkoti in #2546 +* Fix CI failures caused by docs build memory exhaustion by @pankajkoti in #2580 +* Fix dbt Fusion broken integration tests by @tatiana in #2581 +* Fix flaky ``cosmos_manifest_selectors_example`` DAG in CI by @pankajkoti in #2593 +* Reduce pre-commit autoupdate frequency PRs by @tatiana in #2544 +* Bump ``reviewdog/action-actionlint`` from 1.71.0 to 1.72.0 by @dependabot in #2542 +* Skip watcher gateway test on Airflow 3.0 by @tatiana in #2607 + 1.14.0 (2026-04-07) --------------------- diff --git a/cosmos/__init__.py b/cosmos/__init__.py index 00bed8945a..182642d863 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -10,7 +10,7 @@ import importlib -__version__ = "1.14.0" +__version__ = "1.14.1" # Mapping of public names to their module paths for lazy loading via __getattr__. diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 73de6c5748..08cf9a1f7e 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -25,6 +25,7 @@ DBT_SETUP_ASYNC_TASK_ID, DBT_TEARDOWN_ASYNC_TASK_ID, DEFAULT_DBT_RESOURCES, + PRODUCER_WATCHER_DONE_TASK_ID, PRODUCER_WATCHER_TASK_ID, SUPPORTED_BUILD_RESOURCES, TESTABLE_DBT_RESOURCES, @@ -664,7 +665,7 @@ def _add_dbt_setup_async_task( else [task_or_taskgroup] ) for task in node_tasks: - task.producer_task_id = setup_airflow_task.task_id # type: ignore[attr-defined] + task.producer_task_id = setup_airflow_task.task_id # type: ignore[union-attr] if not task.upstream_list: setup_airflow_task >> task @@ -716,6 +717,20 @@ def _add_watcher_producer_task( ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task + + # For DbtTaskGroup, add a gate task that absorbs the producer's skip state + # so it does not propagate to tasks downstream of the group. + # Not needed for DbtDag where producer >> consumers with trigger_rule="always" handles this. + if task_group is not None: + from cosmos.operators._watcher.base import create_producer_done_task + + producer_done_task = create_producer_done_task( + dag=dag, + task_group=task_group, + task_id=PRODUCER_WATCHER_DONE_TASK_ID, + ) + producer_airflow_task >> producer_done_task + return producer_airflow_task @@ -732,8 +747,8 @@ def _add_watcher_dependencies( - make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom """ for node_id, task_or_taskgroup in tasks_map.items(): - # We do not want to set a dependency between the producer task and itself - if node_id == PRODUCER_WATCHER_TASK_ID: + # We do not want to set a dependency between the producer task (or its gate) and itself + if node_id in (PRODUCER_WATCHER_TASK_ID, PRODUCER_WATCHER_DONE_TASK_ID): continue node_tasks = ( @@ -742,7 +757,7 @@ def _add_watcher_dependencies( else [task_or_taskgroup] ) for task in node_tasks: - task.producer_task_id = producer_airflow_task.task_id # type: ignore[attr-defined] + task.producer_task_id = producer_airflow_task.task_id # type: ignore[union-attr] # Make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom # We only managed to do this in the case of DbtDag. @@ -761,7 +776,7 @@ def _add_watcher_dependencies( else: always_run_tasks = [task_or_taskgroup] for task in always_run_tasks: - task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined] + task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[union-attr] def should_create_detached_nodes(render_config: RenderConfig) -> bool: diff --git a/cosmos/config.py b/cosmos/config.py index 24591eeab8..5998ab2888 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -9,7 +9,7 @@ from collections.abc import Callable, Iterator from dataclasses import InitVar, dataclass, field from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import yaml @@ -34,7 +34,9 @@ from cosmos.dbt.executable import get_system_dbt from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger -from cosmos.profiles import BaseProfileMapping + +if TYPE_CHECKING: + from cosmos.profiles import BaseProfileMapping logger = get_logger(__name__) @@ -116,6 +118,14 @@ def __post_init__(self, dbt_project_path: str | Path | None) -> None: "RenderConfig.dbt_deps is deprecated since Cosmos 1.9 and will be removed in Cosmos 2.0. Use ProjectConfig.install_dbt_deps instead.", DeprecationWarning, ) + if self.source_rendering_behavior is None: + warnings.warn( + "Passing None for source_rendering_behavior is not supported. " + "Use SourceRenderingBehavior.NONE to disable source rendering. " + "Defaulting to SourceRenderingBehavior.NONE.", + UserWarning, + ) + self.source_rendering_behavior = SourceRenderingBehavior.NONE self.project_path = Path(dbt_project_path) if dbt_project_path else None # allows us to initiate this attribute from Path objects and str self.dbt_ls_path = Path(self.dbt_ls_path) if self.dbt_ls_path else None @@ -321,7 +331,9 @@ def validate_profiles_yml(self) -> None: raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.") def get_profile_type(self) -> str: - if isinstance(self.profile_mapping, BaseProfileMapping): + from cosmos.profiles.base import BaseProfileMapping as _BaseProfileMapping + + if isinstance(self.profile_mapping, _BaseProfileMapping): return str(self.profile_mapping.dbt_profile_type) profile_path = self._get_profile_path() diff --git a/cosmos/constants.py b/cosmos/constants.py index 1f6b3439dd..0fd938ced4 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -185,6 +185,7 @@ def _missing_value_(cls, value): # type: ignore CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT = 2 PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT = 20 PRODUCER_WATCHER_TASK_ID = "dbt_producer_watcher" +PRODUCER_WATCHER_DONE_TASK_ID = f"{PRODUCER_WATCHER_TASK_ID}_done" # Historical telemetry endpoints retained for reference: # • v1 (Cosmos 1.8.0–1.10.x) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index fad70922b8..9e090c03ea 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -553,6 +553,40 @@ def get_dbt_yaml_selectors_cache_key_args(self, impl_version: str) -> list[str]: logger.debug(f"Value of `dbt_yaml_selectors_cache_key` for <{self.cache_key}>: {cache_args}") return cache_args + def _save_cache_to_variable(self, cache_dict: dict[str, Any], cache_name: str) -> None: + """Write cache_dict to an Airflow Variable, warning on AirflowRuntimeError. + + This failure can be expected in DAG parsing environments where the scheduler or + DAG processor does not have direct access to a usable Airflow metadata database + (for example, when the ``variable`` table is unavailable). + """ + try: + from airflow.sdk.exceptions import AirflowRuntimeError + except ImportError: + Variable.set(self.cache_key, cache_dict, serialize_json=True) + return + + is_yaml_cache = "yaml" in cache_name.lower() + disable_cache_env_var = ( + "AIRFLOW__COSMOS__ENABLE_CACHE_DBT_YAML_SELECTORS" + if is_yaml_cache + else "AIRFLOW__COSMOS__ENABLE_CACHE_DBT_LS" + ) + cache_specific_workaround = "" if is_yaml_cache else ", using LoadMode.DBT_MANIFEST" + try: + Variable.set(self.cache_key, cache_dict, serialize_json=True) + except AirflowRuntimeError as e: + logger.warning( + "Failed to save Cosmos %s cache to Airflow Variable '%s': %s. " + "Consider setting AIRFLOW__COSMOS__REMOTE_CACHE_DIR to use object storage for caching%s, " + "or disabling the cache via %s.", + cache_name, + self.cache_key, + e, + cache_specific_workaround, + disable_cache_env_var, + ) + def save_dbt_ls_cache(self, dbt_ls_output: str) -> None: """ Store compressed dbt ls output into an Airflow Variable. @@ -582,7 +616,7 @@ def save_dbt_ls_cache(self, dbt_ls_output: str) -> None: with remote_cache_key_path.open("w") as fp: json.dump(cache_dict, fp) else: - Variable.set(self.cache_key, cache_dict, serialize_json=True) + self._save_cache_to_variable(cache_dict, "dbt ls") def _get_dbt_ls_remote_cache(self, remote_cache_dir: Path | ObjectStoragePath) -> dict[str, str]: """Loads the remote cache for dbt ls.""" @@ -1104,7 +1138,7 @@ def save_yaml_selectors_cache(self, yaml_selectors: YamlSelectors) -> None: with remote_cache_key_path.open("w") as fp: json.dump(cache_dict, fp) else: - Variable.set(self.cache_key, cache_dict, serialize_json=True) + self._save_cache_to_variable(cache_dict, "YAML selectors") def parse_yaml_selectors(self, selector_definitions: dict[str, Any]) -> YamlSelectors: """ @@ -1116,8 +1150,9 @@ def parse_yaml_selectors(self, selector_definitions: dict[str, Any]) -> YamlSele Returns: YamlSelectors: A YamlSelectors instance """ - - yaml_selectors = YamlSelectors.parse(selector_definitions) + yaml_selectors = YamlSelectors.parse( + selector_definitions, lax_parsing_enabled=settings.enable_lax_selector_parsing + ) if self.should_use_yaml_selectors_cache(): self.save_yaml_selectors_cache(yaml_selectors) @@ -1194,9 +1229,11 @@ def _apply_manifest_node_selection(self, nodes: dict[str, DbtNode], manifest: di yaml_selectors = self.load_parsed_selectors(selector_definitions) selections = yaml_selectors.get_parsed(self.render_config.selector) if not selections: - raise CosmosLoadDbtException( - f"Selector `{self.render_config.selector}` not found in parsed YAML selectors `{selector_definitions}`" - ) + error_message = f"Selector `{self.render_config.selector}` not found in parsed YAML selectors" + if settings.enable_lax_selector_parsing: + error_message += ". Check logs - the selector may have parsing errors logged as warnings" + raise CosmosLoadDbtException(error_message) + self.nodes = nodes self.filtered_nodes = select_nodes( project_dir=project_dir, diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index cdde43a198..9c583f71f7 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -1349,7 +1349,7 @@ def _parse_from_definition( return (select, exclude, errors) @classmethod - def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: + def parse(cls, selectors: dict[str, dict[str, Any]], lax_parsing_enabled: bool) -> YamlSelectors: """ Parse selector definitions from a dbt manifest into a YamlSelectors instance. @@ -1357,6 +1357,7 @@ def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: selector definitions from the manifest and converts them to Cosmos format. :param selectors: dict[str, dict[str, Any]] - Dictionary of selector definitions from the manifest, keyed by selector name + :param lax_parsing_enabled: bool - Flag to determine whether parser exceptions should be logged instead of raised :return: YamlSelectors - A YamlSelectors instance containing both raw and parsed selector definitions Example Input: @@ -1386,12 +1387,24 @@ def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: for name, definition in selectors.items(): if not isinstance(definition, dict): - raise CosmosValueError( + error_message = ( f"Invalid selector definition for '{name}'. Expected a dict, got {type(definition)}: {definition}" ) + if lax_parsing_enabled: + logger.warning(error_message) + continue + else: + raise CosmosValueError(error_message) if not definition.get("name") or not definition.get("definition"): - raise CosmosValueError(f"Selector definition for '{name}' must contain 'name' and 'definition' keys.") + error_message = f"Selector definition for '{name}' must contain 'name' and 'definition' keys." + if lax_parsing_enabled: + logger.warning(error_message) + continue + else: + raise CosmosValueError( + f"Selector definition for '{name}' must contain 'name' and 'definition' keys." + ) selector_name = definition["name"] selector_definition = definition["definition"] diff --git a/cosmos/log.py b/cosmos/log.py index 15c58cfb2c..c56413dfb7 100644 --- a/cosmos/log.py +++ b/cosmos/log.py @@ -2,14 +2,13 @@ import logging -from cosmos.settings import rich_logging - class CosmosRichLogger(logging.Logger): """Custom Logger that prepends ``(astronomer-cosmos)`` to each log message in the scheduler.""" def handle(self, record: logging.LogRecord) -> None: - record.msg = "\x1b[35m(astronomer-cosmos)\x1b[0m " + record.msg + if record.msg is not None: + record.msg = "\x1b[35m(astronomer-cosmos)\x1b[0m " + str(record.msg) return super().handle(record) @@ -24,7 +23,14 @@ def get_logger(name: str) -> logging.Logger: as long as the ``rich_logging`` setting is True: [2023-08-09T14:20:55.532+0100] {subprocess.py:94} INFO - (astronomer-cosmos) - 13:20:55 Completed successfully """ - if rich_logging: + # Import cosmos.settings at call time (not module level) to avoid circular imports + # during Airflow plugin discovery. getattr tolerates the module being partially + # initialized (rich_logging not yet defined) by falling back to False. + import cosmos.settings as _settings + + _rich = getattr(_settings, "rich_logging", False) + + if _rich: cls = logging.getLoggerClass() try: logging.setLoggerClass(CosmosRichLogger) diff --git a/cosmos/operators/_watcher/__init__.py b/cosmos/operators/_watcher/__init__.py index 02ed3411e2..409eb1e175 100644 --- a/cosmos/operators/_watcher/__init__.py +++ b/cosmos/operators/_watcher/__init__.py @@ -1,12 +1,12 @@ from __future__ import annotations __all__ = [ - "get_xcom_val", - "safe_xcom_push", "build_producer_state_fetcher", - "is_dbt_node_status_success", + "get_xcom_val", "is_dbt_node_status_failed", + "is_dbt_node_status_success", "is_dbt_node_status_terminal", + "safe_xcom_push", "WatcherTrigger", ] diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index aa47e81291..62bae3de7f 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -4,7 +4,7 @@ import logging import threading from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException, AirflowSkipException @@ -28,6 +28,7 @@ is_dbt_node_status_skipped, is_dbt_node_status_success, is_dbt_node_status_terminal, + is_producer_task_terminated, safe_xcom_push, xcom_set_lock, ) @@ -40,6 +41,15 @@ from airflow.sensors.base import BaseSensorOperator from airflow.utils.context import Context # type: ignore[attr-defined] +if TYPE_CHECKING: + from airflow.models.dag import DAG + from airflow.operators.empty import EmptyOperator + + try: + from airflow.sdk import TaskGroup + except ImportError: + from airflow.utils.task_group import TaskGroup + logger = get_logger(__name__) # Subset of dbt event types that represent errors/failures. @@ -413,12 +423,19 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo f"A future release will add fallback to local test execution." ) - logger.info( - f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}", - try_number - 1, - self.model_unique_id, - self.project_dir, - ) + if try_number > 1: + logger.info( + f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}", + try_number - 1, + self.model_unique_id, + self.project_dir, + ) + else: + logger.info( + f"Falling back to running model '%s' from project '%s' using {self.__class__.__name__}", + self.model_unique_id, + self.project_dir, + ) upstream_task = context["ti"].task.dag.get_task(self.producer_task_id) @@ -544,6 +561,16 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: f"Watcher producer task '{self.producer_task_id}' failed before reporting results for {self._resource_label.lower()} '{self.model_unique_id}'. Check its logs for the underlying error." ) + if reason == WatcherEventReason.PRODUCER_SKIPPED: + logger.info( + "Watcher producer task '%s' was skipped. Falling back to running dbt for %s '%s'.", + self.producer_task_id, + self._resource_label.lower(), + self.model_unique_id, + ) + self._fallback_to_non_watcher_run(try_number=context["ti"].try_number, context=context) + return + def _log_startup_events(self, ti: Any) -> None: dbt_startup_events: list[dict[str, Any]] = ti.xcom_pull( task_ids=self.producer_task_id, key=_DBT_STARTUP_EVENTS_XCOM_KEY @@ -583,6 +610,45 @@ def _cache_compiled_sql(self, ti: Any, context: Context) -> None: if hasattr(self, "_override_rtif"): self._override_rtif(context) + def _handle_retry(self, try_number: int, producer_task_state: str | None, context: Context) -> bool | None: + """Handle sensor retry by checking whether the producer is still active. + + Returns the fallback result if the producer has terminated, or None if + the sensor should continue polling (producer still active). + """ + if is_producer_task_terminated(producer_task_state): + # Producer finished — this is either an automatic retry after + # the producer completed or a manual task clear from the UI. + # Fall back to a non-watcher run. + return self._fallback_to_non_watcher_run(try_number, context) + # Producer is still active — the sensor likely timed out while the + # producer was still working. Keep polling instead of launching a + # duplicate dbt run. + logger.info( + "Try #%s but producer '%s' is still %s — continuing to poll instead of fallback.", + try_number, + self.producer_task_id, + producer_task_state or "unknown", + ) + return None + + def _handle_no_dbt_node_status(self, producer_task_state: str | None, try_number: int, context: Context) -> bool: + """Handle the case where no dbt node status has been reported yet.""" + if producer_task_state == "failed": + if self.poke_retry_number > 0: + raise AirflowException( + f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." + ) + else: + # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` + return self._fallback_to_non_watcher_run(try_number, context) + + if producer_task_state == "skipped": + return self._fallback_to_non_watcher_run(try_number, context) + + self.poke_retry_number += 1 + return False + def poke(self, context: Context) -> bool: """ Checks the status of a dbt node (model or aggregated tests) by pulling relevant XComs from the producer task. @@ -605,10 +671,13 @@ def poke(self, context: Context) -> bool: self.model_unique_id, ) + producer_task_state = self._get_producer_task_status(context) + if try_number > 1: - return self._fallback_to_non_watcher_run(try_number, context) + retry_result = self._handle_retry(try_number, producer_task_state, context) + if retry_result is not None: + return retry_result - producer_task_state = self._get_producer_task_status(context) if not self.is_test_sensor: self._log_startup_events(ti) status = self._get_node_status(ti, context) @@ -625,19 +694,7 @@ def poke(self, context: Context) -> bool: _log_dbt_event(dbt_events) if status is None: - - if producer_task_state == "failed": - if self.poke_retry_number > 0: - raise AirflowException( - f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." - ) - else: - # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` - return self._fallback_to_non_watcher_run(try_number, context) - - self.poke_retry_number += 1 - - return False + return self._handle_no_dbt_node_status(producer_task_state, try_number, context) elif is_dbt_node_status_skipped(status): raise AirflowSkipException( f"{self._resource_label} '{self.model_unique_id}' was skipped by the dbt command." @@ -646,3 +703,32 @@ def poke(self, context: Context) -> bool: return True else: raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'") + + +def create_producer_done_task(dag: DAG, task_group: TaskGroup, task_id: str) -> EmptyOperator: + """Create an EmptyOperator that absorbs the producer's skip state on retry. + + This task sits downstream of the producer inside the DbtTaskGroup. When the producer + is skipped on retry, this task still succeeds (trigger_rule=NONE_FAILED), preventing + the skip from propagating to tasks downstream of the group. + """ + try: + from airflow.providers.standard.operators.empty import EmptyOperator + except ImportError: + from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] + + try: + from airflow.task.trigger_rule import TriggerRule + except ImportError: + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef] + + return EmptyOperator( # type: ignore[no-untyped-call] + task_id=task_id, + dag=dag, + task_group=task_group, + trigger_rule=TriggerRule.NONE_FAILED, + doc_md=( + "**Cosmos internal task.** Absorbs the producer's skip state on retry " + "so it does not propagate to tasks downstream of this DbtTaskGroup." + ), + ) diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index 27d3b488b0..744c5de322 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -25,6 +25,10 @@ DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"}) DBT_SKIPPED_STATUSES = frozenset({"skipped"}) +# Airflow task states that indicate the producer has finished and will not deliver any more XCom updates. +# Used to decide whether a sensor retry should fall back to a non-watcher run or keep polling. +PRODUCER_TERMINAL_STATES = frozenset({"success", "failed", "skipped", "upstream_failed", "removed"}) + class DbtTestStatus(str, Enum): """Aggregated status of all tests for a given model.""" @@ -55,6 +59,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool: return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status) +def is_producer_task_terminated(state: str | None) -> bool: + """Return True when the producer task is in a terminal state.""" + return state in PRODUCER_TERMINAL_STATES + + xcom_set_lock = Lock() @@ -68,6 +77,13 @@ def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None: """ with xcom_set_lock: task_instance.xcom_push(key=key, value=value) + var_key = getattr(task_instance, "_cosmos_xcom_backup_var_key", None) + if isinstance(var_key, str): + from cosmos.operators._watcher.xcom import _persist_backup + + backup_buffer: dict[str, Any] = task_instance._cosmos_xcom_backup_buffer # type: ignore[attr-defined] + backup_buffer[key] = value + _persist_backup(var_key, backup_buffer) # TODO: Unify the Airflow call from cosmos/operators/_watcher/triggerer.py and cosmos/operators/watcher.py diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index c104a3c334..cd88099e96 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -29,6 +29,7 @@ class WatcherEventReason(str, Enum): NODE_FAILED = "node_failed" PRODUCER_FAILED = "producer_failed" + PRODUCER_SKIPPED = "producer_skipped" NODE_NOT_RUN = "node_not_run" @@ -201,7 +202,7 @@ async def _log_startup_events(self) -> None: # logging the full list. Also ensures we exit if producer finishes before # ever pushing _DBT_STARTUP_EVENTS_XCOM_KEY. producer_task_state = await self._get_producer_task_status() - if producer_task_state in ("failed", "success"): + if producer_task_state in ("failed", "success", "skipped"): return # Return if dbt node is in terminal state @@ -211,6 +212,16 @@ async def _log_startup_events(self) -> None: await asyncio.sleep(self.poke_interval) + def _build_success_event(self, compiled_sql: str | None) -> dict[str, Any]: + """Build the TriggerEvent payload for a successful dbt node.""" + event_data: dict[str, Any] = {"status": EventStatus.SUCCESS} + if compiled_sql: + event_data["compiled_sql"] = compiled_sql + outlet_uris = getattr(self, "_outlet_uris", []) + if outlet_uris: + event_data["outlet_uris"] = outlet_uris + return event_data + async def run(self) -> AsyncIterator[TriggerEvent]: logger.info("Starting WatcherTrigger for model: %s", self.model_unique_id) await self._log_startup_events() @@ -222,14 +233,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: dbt_node_status, compiled_sql = await self._parse_dbt_node_status_and_compiled_sql() if is_dbt_node_status_success(dbt_node_status): logger.debug("dbt node '%s' succeeded", self.model_unique_id) - event_data: dict[str, Any] = {"status": EventStatus.SUCCESS} - if compiled_sql: - event_data["compiled_sql"] = compiled_sql - # Pass outlet URIs through TriggerEvent so consumer can emit datasets - outlet_uris = getattr(self, "_outlet_uris", []) - if outlet_uris: - event_data["outlet_uris"] = outlet_uris - yield TriggerEvent(event_data) # type: ignore[no-untyped-call] + yield TriggerEvent(self._build_success_event(compiled_sql)) # type: ignore[no-untyped-call] return elif is_dbt_node_status_skipped(dbt_node_status): logger.info("dbt node '%s' was skipped", self.model_unique_id) @@ -250,6 +254,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) yield TriggerEvent({"status": EventStatus.FAILED, "reason": WatcherEventReason.PRODUCER_FAILED}) # type: ignore[no-untyped-call] return + elif producer_task_state == "skipped": + logger.info( + "Watcher producer task '%s' was skipped (e.g. retry). " + "Consumer will fall back to running dbt for node '%s'.", + self.producer_task_id, + self.model_unique_id, + ) + yield TriggerEvent({"status": EventStatus.FAILED, "reason": WatcherEventReason.PRODUCER_SKIPPED}) # type: ignore[no-untyped-call] + return elif producer_task_state == "success" and dbt_node_status is None: logger.info( "The producer task '%s' succeeded. There is no information about the node '%s' execution.", diff --git a/cosmos/operators/_watcher/xcom.py b/cosmos/operators/_watcher/xcom.py new file mode 100644 index 0000000000..a2b892f406 --- /dev/null +++ b/cosmos/operators/_watcher/xcom.py @@ -0,0 +1,126 @@ +"""XCom backup/restore logic for watcher producer retry resilience. + +When the watcher producer fails and retries, Airflow clears XCom entries from +the previous attempt. This module provides an incremental backup mechanism +that persists XCom key/value pairs to an Airflow Variable so they can be +restored on retry. +""" + +from __future__ import annotations + +import base64 +import json +import zlib +from typing import Any + +from cosmos.log import get_logger +from cosmos.operators._watcher.state import safe_xcom_push + +logger = get_logger(__name__) + +XCOM_BACKUP_VARIABLE_PREFIX = "cosmos_xcom_backup" + + +def _xcom_backup_variable_key(dag_id: str, task_group_id: str | None, run_id: str) -> str: + """Build a unique Airflow Variable key for the XCom backup of a watcher producer run.""" + parts = [XCOM_BACKUP_VARIABLE_PREFIX, dag_id.replace(".", "___")] + if task_group_id: + parts.append(task_group_id.replace(".", "__")) + parts.append(run_id.replace(".", "_")) + return "__".join(parts) + + +def _get_task_group_id(ti: Any) -> str | None: + """Extract the task_group_id from a task instance, if available.""" + task = getattr(ti, "task", None) + return getattr(task, "task_group_id", None) if task else None + + +def _init_xcom_backup(context: Any) -> None: + """Activate incremental XCom backup for the current producer execution. + + After this call every ``safe_xcom_push`` will also persist the key/value + pair to an Airflow Variable so the backup is crash-safe. The state is + stored on the task instance itself — no module-level globals. + """ + ti = context["ti"] + dag_id = ti.dag_id + run_id = context["run_id"] + task_group_id = _get_task_group_id(ti) + ti._cosmos_xcom_backup_var_key = _xcom_backup_variable_key(dag_id, task_group_id, run_id) # type: ignore[attr-defined] + ti._cosmos_xcom_backup_buffer = {} # type: ignore[attr-defined] + + +def _persist_backup(var_key: str, backup_buffer: dict[str, Any]) -> None: + """Write the current backup buffer to an Airflow Variable.""" + if not backup_buffer: + return + + from airflow.models import Variable + + compressed = base64.b64encode(zlib.compress(json.dumps(backup_buffer, default=str).encode("utf-8"))).decode("utf-8") + Variable.set(var_key, compressed) + logger.debug("Persisted %d XCom entries to Variable '%s'", len(backup_buffer), var_key) + + +def _backup_xcom_to_variable(context: Any) -> None: + """Persist all XCom entries for the current producer task in an Airflow Variable. + + The backup is built incrementally by ``safe_xcom_push`` (see + ``_init_xcom_backup``). This call does a final flush. + """ + ti = context["ti"] + var_key = getattr(ti, "_cosmos_xcom_backup_var_key", None) + if not isinstance(var_key, str): + return + + backup_buffer = getattr(ti, "_cosmos_xcom_backup_buffer", {}) + _persist_backup(var_key, backup_buffer) + if backup_buffer: + logger.debug("Backed up %d XCom entries to Variable '%s'", len(backup_buffer), var_key) + + +def _delete_xcom_backup_variable(context: Any) -> None: + """Delete the XCom backup Variable after a successful producer execution.""" + ti = context["ti"] + var_key = getattr(ti, "_cosmos_xcom_backup_var_key", None) + if not isinstance(var_key, str): + return + try: + from airflow.models import Variable + + Variable.delete(var_key) + logger.debug("Deleted XCom backup Variable '%s'", var_key) + except KeyError: + pass + + +def _restore_xcom_from_variable(context: Any) -> bool: + """Restore XCom entries from an Airflow Variable backup created by a previous attempt. + + Returns True if the restore succeeded, False if no backup was found. + """ + from airflow.models import Variable + + ti = context["ti"] + dag_id = ti.dag_id + run_id = context["run_id"] + task_group_id = _get_task_group_id(ti) + + var_key = _xcom_backup_variable_key(dag_id, task_group_id, run_id) + compressed = Variable.get(var_key, default_var=None) + if compressed is None: + logger.info("No XCom backup Variable found at '%s'", var_key) + return False + + backup: dict[str, Any] = json.loads(zlib.decompress(base64.b64decode(compressed.encode("utf-8"))).decode("utf-8")) + for key, value in backup.items(): + safe_xcom_push(task_instance=ti, key=key, value=value) + logger.info("Restored %d XCom entries from Variable '%s'", len(backup), var_key) + + try: + Variable.delete(var_key) + logger.debug("Deleted XCom backup Variable '%s' after restore", var_key) + except KeyError: + logger.debug("XCom backup Variable '%s' already deleted", var_key) + return True diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index b77c112099..4fed25075a 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -690,7 +690,7 @@ def run_command( # noqa: C901 self.handle_exception(result) return result if is_openlineage_common_available: - self.calculate_openlineage_events_completes(env, tmp_dir_path) + self.calculate_openlineage_events_completes(env, tmp_dir_path, full_cmd) if AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION: # Airflow 3 does not support associating 'openlineage_events_completes' with task_instance, # in that case we're storing as self.openlineage_events_completes @@ -711,7 +711,10 @@ def run_command( # noqa: C901 return result def calculate_openlineage_events_completes( - self, env: dict[str, str | os.PathLike[Any] | bytes], project_dir: Path + self, + env: dict[str, str | os.PathLike[Any] | bytes], + project_dir: Path, + dbt_command_line: list[str] | None = None, ) -> None: """ Use openlineage-integration-common to extract lineage events from the artifacts generated after running the dbt @@ -728,13 +731,17 @@ def calculate_openlineage_events_completes( for key, value in env.items(): os.environ[key] = str(value) - openlineage_processor = DbtLocalArtifactProcessor( + processor_kwargs: dict[str, Any] = dict( producer=OPENLINEAGE_PRODUCER, job_namespace=settings.LINEAGE_NAMESPACE, project_dir=project_dir, profile_name=self.profile_config.profile_name, target=self.profile_config.target_name, ) + sig = inspect.signature(DbtLocalArtifactProcessor.__init__) + if "dbt_command_line" in sig.parameters and dbt_command_line is not None: + processor_kwargs["dbt_command_line"] = dbt_command_line + openlineage_processor = DbtLocalArtifactProcessor(**processor_kwargs) # Do not raise exception if a command is unsupported, following the openlineage-dbt processor: # https://github.com/OpenLineage/OpenLineage/blob/bdcaf828ebc117e0e5ffc5fab44ff8886eb7836b/integration/common/openlineage/common/provider/dbt/processor.py#L141 openlineage_processor.should_raise_on_unsupported_command = False diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index b007a4f433..70169e6f5a 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -5,7 +5,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException try: # Airflow 3.1 onwards @@ -28,6 +28,12 @@ BaseConsumerSensor, store_dbt_resource_status_from_log, ) +from cosmos.operators._watcher.xcom import ( + _backup_xcom_to_variable, + _delete_xcom_backup_variable, + _init_xcom_backup, + _restore_xcom_from_variable, +) from cosmos.operators.base import ( DbtBuildMixin, DbtRunMixin, @@ -62,10 +68,16 @@ def _default_freshness_callback( nodes: dict[str, DbtNode] | None, sources_json: dict[str, Any] | None, ) -> tuple[list[str], str]: - """Return unique_ids of all nodes that transitively depend on a stale source, plus the status ``"skip"``. + """Return unique_ids of nodes that must be skipped due to stale sources, plus the status ``"skip"``. Stale sources are those with ``status`` of ``"error"`` or ``"warn"`` in ``sources_json["results"]``. - Traversal is DFS over the reverse-dependency graph built from ``nodes``. + + A node is skipped only when **all** of its upstream dependencies are either stale sources or + already-skipped nodes. If a node has at least one clean upstream path it may still execute + successfully — for example when a model depends on both a stale source and a clean model — so + it is excluded from the skip set and allowed to run. + + Traversal is a DFS over the reverse-dependency graph built from ``nodes``. """ if not nodes or not sources_json: return [], "skip" @@ -80,14 +92,23 @@ def _default_freshness_callback( for dep_id in node.depends_on: dependents.setdefault(dep_id, set()).add(node_id) - # DFS from each stale source to collect all transitive dependents + # DFS from each stale source. A dependent is added to the skip set only when every entry in + # its depends_on is either a known-stale source or already in the skip set. This preserves + # nodes that have at least one clean upstream path: they may succeed and should not be + # preemptively excluded. When a new node is added to visited its own dependents are queued + # so they can be re-evaluated with the updated skip set. _excludable_resource_types = {DbtResourceType.MODEL, DbtResourceType.SEED, DbtResourceType.SNAPSHOT} visited: set[str] = set() queue = list(stale_source_ids) while queue: current = queue.pop() for dependent_id in dependents.get(current, set()): - if dependent_id not in visited: + if dependent_id in visited: + continue + dependent_node = nodes.get(dependent_id) + if dependent_node is None: + continue + if all(dep in stale_source_ids or dep in visited for dep in dependent_node.depends_on): visited.add(dependent_id) queue.append(dependent_id) @@ -212,9 +233,21 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kw When a callback is registered, dbt's stdout is redirected to a null buffer so that the raw ``--log-format json`` lines do not appear in Airflow task logs alongside the human-readable messages already emitted by ``_log_dbt_msg`` inside the callback. + + The callback is intentionally **not** registered during the source freshness pre-check + (``context["_check_source_freshness"] is True``). Registering it there would leave a + stale entry in ``_dbt_runner_callbacks`` that fires again for every event during the + subsequent ``dbt build``, producing duplicate log lines. Freshness results are read from + ``target/sources.json`` after the run and do not need per-event XCom pushes. """ context = kwargs.get("context") if context is not None: + # During the source freshness pre-check suppress raw JSON stdout, but do not register + # the XCom-pushing callback so it cannot accumulate and duplicate build logs later. + if context.get("_check_source_freshness"): + with contextlib.redirect_stdout(_NullWriter()): + return super().run_dbt_runner(command, env, cwd, **kwargs) + extra_kwargs: dict[str, Any] = {"project_dir": cwd, "context": context} parse = self._make_parse_callable() # Collect callback errors rather than raising inside the callback: dbt catches @@ -356,12 +389,13 @@ def execute(self, context: Context, **kwargs: Any) -> Any: try_number = getattr(task_instance, "try_number", 1) if try_number > 1: - self.log.info( + _restore_xcom_from_variable(context) + raise AirflowSkipException( "Dbt WATCHER producer task does not support Airflow retries. " - "Detected attempt #%s; skipping execution to avoid running a second dbt build.", - try_number, + f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build." ) - return None + + _init_xcom_backup(context) if self._check_source_freshness: self._apply_source_freshness(context) @@ -369,10 +403,12 @@ def execute(self, context: Context, **kwargs: Any) -> Any: try: return_value = super().execute(context=context, **kwargs) safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") + _delete_xcom_backup_variable(context) return return_value except Exception: safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") + _backup_xcom_to_variable(context) raise diff --git a/cosmos/operators/watcher_kubernetes.py b/cosmos/operators/watcher_kubernetes.py index 1d72defc9e..3cab4835e4 100644 --- a/cosmos/operators/watcher_kubernetes.py +++ b/cosmos/operators/watcher_kubernetes.py @@ -13,7 +13,7 @@ from airflow.utils.context import Context # type: ignore[attr-defined] import kubernetes.client as k8s -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback, client_type try: @@ -24,6 +24,12 @@ from cosmos.airflow._override import CosmosKubernetesPodManager from cosmos.log import get_logger from cosmos.operators._watcher.base import BaseConsumerSensor, store_dbt_resource_status_from_log +from cosmos.operators._watcher.xcom import ( + _backup_xcom_to_variable, + _delete_xcom_backup_variable, + _init_xcom_backup, + _restore_xcom_from_variable, +) from cosmos.operators.base import ( DbtRunMixin, DbtSeedMixin, @@ -107,18 +113,26 @@ def execute(self, context: Context, **kwargs: Any) -> Any: try_number = getattr(task_instance, "try_number", 1) if try_number > 1: - self.log.info( + _restore_xcom_from_variable(context) + raise AirflowSkipException( "DbtProducerWatcherKubernetesOperator does not support Airflow retries. " - "Detected attempt #%s; skipping execution to avoid running a second dbt build.", - try_number, + f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build." ) - return None + + _init_xcom_backup(context) # This global variable is used to make the task context available to the K8s callback. # While the callback is set during the operator initialization, the context is only created during the operator's execution. global producer_task_context producer_task_context = context - return super().execute(context, **kwargs) + + try: + return_value = super().execute(context, **kwargs) + _delete_xcom_backup_variable(context) + return return_value + except Exception: + _backup_xcom_to_variable(context) + raise class DbtConsumerWatcherKubernetesSensor(BaseConsumerSensor, DbtRunKubernetesOperator): diff --git a/cosmos/settings.py b/cosmos/settings.py index 5156f6a557..40255fac83 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -32,6 +32,7 @@ enable_cache_package_lockfile = conf.getboolean("cosmos", "enable_cache_package_lockfile", fallback=True) enable_cache_dbt_ls = conf.getboolean("cosmos", "enable_cache_dbt_ls", fallback=True) enable_cache_dbt_yaml_selectors = conf.getboolean("cosmos", "enable_cache_dbt_yaml_selectors", fallback=True) +enable_lax_selector_parsing = conf.getboolean("cosmos", "enable_lax_selector_parsing", fallback=False) rich_logging = conf.getboolean("cosmos", "rich_logging", fallback=False) dbt_docs_dir = conf.get("cosmos", "dbt_docs_dir", fallback=None) dbt_docs_conn_id = conf.get("cosmos", "dbt_docs_conn_id", fallback=None) diff --git a/dev/Dockerfile.watcher_failing_tests_k8s b/dev/Dockerfile.watcher_failing_tests_k8s new file mode 100644 index 0000000000..b7ca917463 --- /dev/null +++ b/dev/Dockerfile.watcher_failing_tests_k8s @@ -0,0 +1,19 @@ +FROM python:3.12 + +RUN pip install --force-reinstall 'dbt-postgres>=1.8' +RUN pip install --force-reinstall dbt-adapters +RUN pip install --force-reinstall awscli + +ENV POSTGRES_DATABASE=postgres +ENV POSTGRES_DB=postgres +ENV POSTGRES_HOST=postgres.default.svc.cluster.local +ENV POSTGRES_PASSWORD=postgres +ENV POSTGRES_PORT=5432 +ENV POSTGRES_SCHEMA=public +ENV POSTGRES_USER=postgres + +RUN mkdir /root/.dbt +COPY dags/dbt/jaffle_shop/profiles.yml /root/.dbt/profiles.yml + +RUN mkdir -p dags/dbt +COPY dags/dbt/watcher_failing_tests dags/dbt/watcher_failing_tests diff --git a/dev/dags/cosmos_manifest_selectors_example.py b/dev/dags/cosmos_manifest_selectors_example.py index 08ac9acfd0..b9aaf2602f 100644 --- a/dev/dags/cosmos_manifest_selectors_example.py +++ b/dev/dags/cosmos_manifest_selectors_example.py @@ -58,7 +58,7 @@ profile_config=profile_config, render_config=RenderConfig( load_method=LoadMode.DBT_MANIFEST, - select=["+customers"], + select=["+customers", "+orders", "raw_orders", "raw_payments"], airflow_vars_to_purge_dbt_yaml_selectors_cache=["purge"], ), execution_config=execution_config, diff --git a/dev/dags/dbt/jaffle_shop/dbt_project.yml b/dev/dags/dbt/jaffle_shop/dbt_project.yml index ba052995db..4f94ca070a 100644 --- a/dev/dags/dbt/jaffle_shop/dbt_project.yml +++ b/dev/dags/dbt/jaffle_shop/dbt_project.yml @@ -17,7 +17,7 @@ clean-targets: - "dbt_modules" - "logs" -require-dbt-version: [">=1.0.0", "<2.0.0"] +require-dbt-version: [">=1.0.0", "<3.0.0"] models: jaffle_shop: diff --git a/dev/dags/dbt/jaffle_shop/package-lock.yml b/dev/dags/dbt/jaffle_shop/package-lock.yml deleted file mode 100644 index 669f9637dc..0000000000 --- a/dev/dags/dbt/jaffle_shop/package-lock.yml +++ /dev/null @@ -1,4 +0,0 @@ -packages: - - package: dbt-labs/dbt_utils - version: 1.1.1 -sha1_hash: a158c48c59c2bb7d729d2a4e215aabe5bb4f3353 diff --git a/dev/dags/dbt/jaffle_shop/packages.yml b/dev/dags/dbt/jaffle_shop/packages.yml index 3af907c8aa..5986d47f9a 100644 --- a/dev/dags/dbt/jaffle_shop/packages.yml +++ b/dev/dags/dbt/jaffle_shop/packages.yml @@ -1,3 +1,3 @@ packages: - package: dbt-labs/dbt_utils - version: "1.1.1" + version: [">=1.1.1", "<1.4.0"] diff --git a/dev/dags/dbt/multi_folder_failing/.gitignore b/dev/dags/dbt/multi_folder_failing/.gitignore new file mode 100644 index 0000000000..bc0684eb2c --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/.gitignore @@ -0,0 +1,4 @@ +target/ +dbt_packages/ +logs/ +dbt_internal_packages/ diff --git a/dev/dags/dbt/multi_folder_failing/dbt_project.yml b/dev/dags/dbt/multi_folder_failing/dbt_project.yml new file mode 100644 index 0000000000..3f7a254b7f --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/dbt_project.yml @@ -0,0 +1,23 @@ +name: 'multi_folder_failing' +version: '0.1.0' +config-version: 2 + +profile: 'default' + +seed-paths: ["seeds"] +model-paths: ["models"] + +test-paths: ["tests"] +target-path: "target" +clean-targets: + - "target" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +models: + multi_folder_failing: + models_a: + +materialized: view + models_b: + +materialized: view diff --git a/dev/dags/dbt/multi_folder_failing/models/models_a/dim_products.sql b/dev/dags/dbt/multi_folder_failing/models/models_a/dim_products.sql new file mode 100644 index 0000000000..00f3cb9d6e --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/models/models_a/dim_products.sql @@ -0,0 +1,6 @@ +-- Depends on model: models_a/stg_products.sql +select + product_id, + product_name, + upper(product_name) as product_name_upper +from {{ ref('stg_products') }} diff --git a/dev/dags/dbt/multi_folder_failing/models/models_a/stg_products.sql b/dev/dags/dbt/multi_folder_failing/models/models_a/stg_products.sql new file mode 100644 index 0000000000..6186531e5a --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/models/models_a/stg_products.sql @@ -0,0 +1,5 @@ +-- Depends on seed: seeds_a/products.csv +select + id as product_id, + name as product_name +from {{ ref('products') }} diff --git a/dev/dags/dbt/multi_folder_failing/models/models_b/dim_failing.sql b/dev/dags/dbt/multi_folder_failing/models/models_b/dim_failing.sql new file mode 100644 index 0000000000..0e0e2654e1 --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/models/models_b/dim_failing.sql @@ -0,0 +1,6 @@ +-- Intentionally broken model: references a nonexistent column +select + region_id, + region_name, + this_column_does_not_exist +from {{ ref('stg_regions') }} diff --git a/dev/dags/dbt/multi_folder_failing/models/models_b/stg_regions.sql b/dev/dags/dbt/multi_folder_failing/models/models_b/stg_regions.sql new file mode 100644 index 0000000000..076595d262 --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/models/models_b/stg_regions.sql @@ -0,0 +1,7 @@ +-- Depends on seeds: seeds_b/regions.csv and seeds_b/region_managers.csv +select + r.id as region_id, + r.name as region_name, + m.manager_name +from {{ ref('regions') }} r +left join {{ ref('region_managers') }} m on r.id = m.region_id diff --git a/dev/dags/dbt/multi_folder_failing/profiles.yml b/dev/dags/dbt/multi_folder_failing/profiles.yml new file mode 100644 index 0000000000..224f565f4a --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/profiles.yml @@ -0,0 +1,12 @@ +default: + target: dev + outputs: + dev: + type: postgres + host: "{{ env_var('POSTGRES_HOST') }}" + user: "{{ env_var('POSTGRES_USER') }}" + password: "{{ env_var('POSTGRES_PASSWORD') }}" + port: "{{ env_var('POSTGRES_PORT') | int }}" + dbname: "{{ env_var('POSTGRES_DB') }}" + schema: "{{ env_var('POSTGRES_SCHEMA') }}" + threads: 4 diff --git a/dev/dags/dbt/multi_folder_failing/seeds/seeds_a/products.csv b/dev/dags/dbt/multi_folder_failing/seeds/seeds_a/products.csv new file mode 100644 index 0000000000..76ca75993a --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/seeds/seeds_a/products.csv @@ -0,0 +1,4 @@ +id,name +1,Widget +2,Gadget +3,Gizmo diff --git a/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/region_managers.csv b/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/region_managers.csv new file mode 100644 index 0000000000..b114874fad --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/region_managers.csv @@ -0,0 +1,11 @@ +region_id,manager_name +1,Alice +2,Bob +3,Carol +4,Dave +5,Eve +6,Frank +7,Grace +8,Henry +9,Ivy +10,Jack diff --git a/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/regions.csv b/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/regions.csv new file mode 100644 index 0000000000..da206a7611 --- /dev/null +++ b/dev/dags/dbt/multi_folder_failing/seeds/seeds_b/regions.csv @@ -0,0 +1,11 @@ +id,name +1,North +2,South +3,East +4,West +5,Central +6,Northeast +7,Northwest +8,Southeast +9,Southwest +10,Midwest diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml b/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml new file mode 100644 index 0000000000..196cb8ba40 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml @@ -0,0 +1,26 @@ +name: 'watcher_downstream_not_skipped' + +config-version: 2 +version: '0.1' + +profile: 'default' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +on-run-start: + - "CREATE SEQUENCE IF NOT EXISTS {{ target.schema }}._cosmos_fail_once_seq" + +models: + watcher_downstream_not_skipped: + +materialized: table diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql new file mode 100644 index 0000000000..2031e20406 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql @@ -0,0 +1,5 @@ +select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email +union all +select 2, 'Bob', 'Jones', 'bob@example.com' +union all +select 3, 'Charlie', 'Brown', 'charlie@example.com' diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql new file mode 100644 index 0000000000..bf5f905bb3 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql @@ -0,0 +1,12 @@ +{{ + config( + pre_hook=[ + "DO $$ BEGIN IF nextval('{{ target.schema }}._cosmos_fail_once_seq') <= 1 THEN RAISE EXCEPTION 'fail_once: intentional first-run failure'; END IF; END $$" + ] + ) +}} + +select + id, + first_name +from {{ ref('model_a') }} diff --git a/dev/dags/dbt/watcher_failing_tests/dbt_project.yml b/dev/dags/dbt/watcher_failing_tests/dbt_project.yml new file mode 100644 index 0000000000..dae8a58e82 --- /dev/null +++ b/dev/dags/dbt/watcher_failing_tests/dbt_project.yml @@ -0,0 +1,23 @@ +name: 'watcher_failing_tests' + +config-version: 2 +version: '0.1' + +profile: 'default' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +models: + watcher_failing_tests: + +materialized: table diff --git a/dev/dags/dbt/watcher_failing_tests/models/model_a.sql b/dev/dags/dbt/watcher_failing_tests/models/model_a.sql new file mode 100644 index 0000000000..2031e20406 --- /dev/null +++ b/dev/dags/dbt/watcher_failing_tests/models/model_a.sql @@ -0,0 +1,5 @@ +select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email +union all +select 2, 'Bob', 'Jones', 'bob@example.com' +union all +select 3, 'Charlie', 'Brown', 'charlie@example.com' diff --git a/dev/dags/dbt/watcher_failing_tests/models/model_f.sql b/dev/dags/dbt/watcher_failing_tests/models/model_f.sql new file mode 100644 index 0000000000..892de7f402 --- /dev/null +++ b/dev/dags/dbt/watcher_failing_tests/models/model_f.sql @@ -0,0 +1,5 @@ +select + id, + first_name, + this_column_does_not_exist_at_all +from {{ ref('model_a') }} diff --git a/dev/failed_dags/example_watcher_downstream_not_skipped.py b/dev/failed_dags/example_watcher_downstream_not_skipped.py new file mode 100644 index 0000000000..417b2336dd --- /dev/null +++ b/dev/failed_dags/example_watcher_downstream_not_skipped.py @@ -0,0 +1,93 @@ +""" +Airflow DAG to verify watcher downstream-task behavior when a producer task is skipped on retry. + +In watcher mode, when a dbt model fails the producer retries and raises AirflowSkipException. +Without the gateway task introduced in Cosmos, the skip would propagate to all tasks downstream +of the TaskGroup (e.g. post_dbt), even though the consumer tasks inside the group succeeded. + +The gateway task (dbt_producer_watcher_done) sits downstream of the producer inside the group +with trigger_rule="none_failed", absorbing the skip so it does not propagate outside. + +This DAG demonstrates that: +- model_a succeeds in the producer and the consumer reads the result from XCom +- model_retry fails on the first attempt (via a fail_once sequence pre-hook) but succeeds + on the consumer retry fallback +- post_dbt (an EmptyOperator downstream of the group) runs successfully — it is NOT skipped + +A cleanup task drops the PostgreSQL sequence so the DAG can be re-triggered. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from airflow.models import DAG +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + +try: + from airflow.providers.standard.operators.empty import EmptyOperator +except ImportError: + from airflow.operators.empty import EmptyOperator + +from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_downstream_not_skipped" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), + "depends_on_past": True, +} + +with DAG( + dag_id="example_watcher_downstream_not_skipped", + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + default_args=default_args, +): + dbt_group = DbtTaskGroup( + group_id="watcher_downstream_not_skipped", + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + ) + + post_dbt = EmptyOperator(task_id="post_dbt") + + cleanup = SQLExecuteQueryOperator( + task_id="drop_fail_once_marker", + conn_id="example_conn", + sql="DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq;", + trigger_rule="all_done", + ) + + dbt_group >> post_dbt + dbt_group >> cleanup diff --git a/dev/failed_dags/example_watcher_failing_multi_folder.py b/dev/failed_dags/example_watcher_failing_multi_folder.py new file mode 100644 index 0000000000..a8fe0470d8 --- /dev/null +++ b/dev/failed_dags/example_watcher_failing_multi_folder.py @@ -0,0 +1,62 @@ +""" +DAG to test watcher retry behaviour with group_nodes_by_folder=True. + +Uses the multi_folder_failing dbt project with a failing model (dim_failing in models_b). +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import ExecutionMode, TestBehavior +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "multi_folder_failing" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, +) + +render_config = RenderConfig( + group_nodes_by_folder=True, + test_behavior=TestBehavior.NONE, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +example_watcher_failing_multi_folder = DbtDag( + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + render_config=render_config, + operator_args=operator_args, + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="example_watcher_failing_multi_folder", + default_args=default_args, +) diff --git a/dev/failed_dags/example_watcher_failing_tests.py b/dev/failed_dags/example_watcher_failing_tests.py new file mode 100644 index 0000000000..bbdfaa951c --- /dev/null +++ b/dev/failed_dags/example_watcher_failing_tests.py @@ -0,0 +1,57 @@ +""" +DAG to reproduce the watcher bug reported in https://github.com/astronomer/astronomer-cosmos/issues/2430 + +Two models: model_a (succeeds) and model_f (fails - references nonexistent column). +Producer retries with retries=2. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_failing_tests" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +example_watcher_failing_tests = DbtDag( + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="example_watcher_failing_tests", + default_args=default_args, +) diff --git a/dev/failed_dags/example_watcher_failing_tests_kubernetes.py b/dev/failed_dags/example_watcher_failing_tests_kubernetes.py new file mode 100644 index 0000000000..5520c2515a --- /dev/null +++ b/dev/failed_dags/example_watcher_failing_tests_kubernetes.py @@ -0,0 +1,94 @@ +""" +DAG to test watcher retry behaviour with ExecutionMode.WATCHER_KUBERNETES. + +Uses the watcher_failing_tests dbt project (model_a succeeds, model_f fails). +Producer retries with retries=2 to exercise the XCom backup/restore mechanism. +""" + +import os +from datetime import timedelta +from pathlib import Path + +from airflow.providers.cncf.kubernetes.secret import Secret +from pendulum import datetime + +from cosmos import DbtDag +from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import ExecutionMode, LoadMode, TestBehavior + +DEFAULT_DBT_ROOT_PATH = Path(__file__).resolve().parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +AIRFLOW_DBT_PROJECT_DIR = DBT_ROOT_PATH / "watcher_failing_tests" + +K8S_PROJECT_DIR = "dags/dbt/watcher_failing_tests" +K8S_DBT_PROFILES_YAML_FILEPATH = Path(K8S_PROJECT_DIR) / "profiles.yml" + +DBT_IMAGE = "dbt-watcher-failing-tests:1.0.0" + +postgres_password_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_PASSWORD", + secret="postgres-secrets", + key="password", +) + +postgres_host_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_HOST", + secret="postgres-secrets", + key="host", +) + +operator_args = { + "deferrable": False, + "image": DBT_IMAGE, + "get_logs": True, + "is_delete_operator_pod": False, + "log_events_on_failure": True, + "secrets": [postgres_password_secret, postgres_host_secret], + "env_vars": { + "POSTGRES_DB": "postgres", + "POSTGRES_SCHEMA": "public", + "POSTGRES_USER": "postgres", + }, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +profile_config = ProfileConfig( + profile_name="postgres_profile", + target_name="dev", + profiles_yml_filepath=K8S_DBT_PROFILES_YAML_FILEPATH, +) + +project_config = ProjectConfig( + project_name="watcher_failing_tests", + manifest_path=AIRFLOW_DBT_PROJECT_DIR / "target/manifest.json", +) + +render_config = RenderConfig( + load_method=LoadMode.DBT_MANIFEST, + test_behavior=TestBehavior.NONE, +) + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +example_watcher_failing_tests_kubernetes = DbtDag( + dag_id="example_watcher_failing_tests_kubernetes", + start_date=datetime(2023, 1, 1), + catchup=False, + project_config=project_config, + profile_config=profile_config, + render_config=render_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER_KUBERNETES, + dbt_project_path=K8S_PROJECT_DIR, + ), + operator_args=operator_args, + default_args=default_args, +) diff --git a/dev/failed_dags/example_watcher_failing_tests_non_deferrable.py b/dev/failed_dags/example_watcher_failing_tests_non_deferrable.py new file mode 100644 index 0000000000..b8cc2a781c --- /dev/null +++ b/dev/failed_dags/example_watcher_failing_tests_non_deferrable.py @@ -0,0 +1,58 @@ +""" +Airflow DAG to test watcher retry behaviour with deferrable=False. + +Two models: model_a (succeeds) and model_f (fails - references nonexistent column). +Consumer sensors use synchronous poke instead of deferring to the triggerer. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_failing_tests" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, +) + +operator_args = { + "install_deps": True, + "deferrable": False, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +example_watcher_failing_tests_non_deferrable = DbtDag( + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="example_watcher_failing_tests_non_deferrable", + default_args=default_args, +) diff --git a/dev/failed_dags/example_watcher_failing_tests_subprocess.py b/dev/failed_dags/example_watcher_failing_tests_subprocess.py new file mode 100644 index 0000000000..9f90794a86 --- /dev/null +++ b/dev/failed_dags/example_watcher_failing_tests_subprocess.py @@ -0,0 +1,58 @@ +""" +DAG to test watcher retry behaviour with SUBPROCESS invocation mode. + +Two models: model_a (succeeds) and model_f (fails - references nonexistent column). +Uses InvocationMode.SUBPROCESS instead of DBT_RUNNER. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode, InvocationMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_failing_tests" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, + invocation_mode=InvocationMode.SUBPROCESS, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +example_watcher_failing_tests_subprocess = DbtDag( + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="example_watcher_failing_tests_subprocess", + default_args=default_args, +) diff --git a/docs/conf.py b/docs/conf.py index d0ef06d2d4..b4f9de10c7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -102,6 +102,7 @@ "getting_started/operators": "../guides/run_dbt/operators/operators.html", "getting_started/watcher-execution-mode": "../guides/run_dbt/airflow-worker/watcher-execution-mode.html", "getting_started/watcher-kubernetes-execution-mode": "../guides/run_dbt/container/watcher-kubernetes-execution-mode.html", + "optimize_performance/partial-parsing": "../guides/run_dbt/customization/partial-parsing.html", "profiles/AthenaAccessKey": "../reference/profiles/AthenaAccessKey.html", "profiles/ClickhouseUserPassword": "../reference/profiles/ClickhouseUserPassword.html", "profiles/DatabricksOauth": "../reference/profiles/DatabricksOauth.html", diff --git a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst index a412103a1c..0073f801db 100644 --- a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst +++ b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst @@ -224,23 +224,67 @@ If a branch of the DAG fails, users can clear the status of a failed consumer ta **Producer retry behavior** -.. versionadded:: 1.12.2 +.. versionchanged:: 1.14.1 -When the ``DbtProducerWatcherOperator`` is triggered for a retry (try_number > 1), it will not re-run the dbt build command and will succeed. In previous versions of Cosmos, the producer task would fail during retries. -This behavior is designed to support TaskGroup-level retries, as reported in `#2282 `_. +When the ``DbtProducerWatcherOperator`` is triggered for a retry (``try_number > 1``), it raises +``AirflowSkipException`` instead of re-running the ``dbt build`` command. Before skipping, it restores +XCom values from a backup so that consumer sensors can still read model statuses from the first attempt. -**Why this matters:** +**XCom backup and restore:** -- In earlier versions, attempting to retry the producer task would raise an ``AirflowException``, causing the retry to fail immediately. -- Now, the producer gracefully skips execution on retries, logging an informational message explaining that the retry was skipped to avoid running a second ``dbt build``. -- This allows users to retry entire TaskGroups and/or DAGs without the producer task blocking the retry flow. +During execution, each XCom push is incrementally backed up to an Airflow Variable. This ensures that +when the producer fails and Airflow clears XCom entries before the retry, the backed-up values can be +restored. On a successful run, the backup Variable is automatically deleted to avoid stale data +accumulating over time. + +**How consumer retries work:** + +1. The producer runs ``dbt build`` — some models succeed, some fail. +2. The producer task fails, and XCom values are backed up to a Variable. +3. On retry, the producer restores XCom from the Variable and raises ``AirflowSkipException``. +4. Consumer sensors read model statuses from the restored XCom. +5. Consumers for successful models complete immediately. +6. Consumers for failed models detect the error status and raise ``AirflowException``. +7. On their own retry, failed consumers fall back to running dbt individually for their model + (via ``_fallback_to_non_watcher_run``), behaving like ``ExecutionMode.LOCAL``. **Important considerations:** -- Retries are no longer forced to ``0`` by Cosmos, since 1.14.0. Users may configure ``retries`` freely on the producer task. On any retry attempt (``try_number > 1``), the producer gracefully skips execution and returns success — it will not re-run the ``dbt build`` command. This means retrying the producer (or clearing an entire TaskGroup) is safe and will not cause duplicate dbt builds. During the retry of the sensor tasks, they will effectively run the corresponding dbt commands. +- Users may configure ``retries`` freely on the producer task. On any retry attempt (``try_number > 1``), + the producer gracefully skips execution — it will not re-run the ``dbt build`` command. +- Retrying the producer (or clearing an entire TaskGroup) is safe and will not cause duplicate dbt builds. +- During the retry of the sensor tasks, they will effectively run the corresponding dbt commands. The overall retry behavior will be further improved once `#1978 `_ is implemented. +Producer done gateway task (DbtTaskGroup only) +............................................... + +.. versionadded:: 1.14.1 + +When using ``DbtTaskGroup`` in watcher mode, the producer task may be skipped on retry +(via ``AirflowSkipException``). Because the producer is a leaf task inside the ``TaskGroup``, +Airflow's default ``trigger_rule="all_success"`` would cause any tasks downstream of the group +to be skipped as well — even when all consumer tasks succeeded. **This makes +``ExecutionMode.WATCHER`` behave differently from ``ExecutionMode.LOCAL``** when used with +``DbtTaskGroup``. + +To prevent this (`#2594 `_), +Cosmos automatically adds an internal gateway task called +``dbt_producer_watcher_done`` inside the ``DbtTaskGroup``. This task: + +- Sits directly downstream of the producer (``producer >> dbt_producer_watcher_done``) +- Uses ``trigger_rule="none_failed"`` so it succeeds even when the producer is skipped +- Absorbs the producer's skip state, preventing it from propagating to tasks downstream of the group + +This gateway is only added for ``DbtTaskGroup`` — ``DbtDag`` does not need it because it handles +the producer-to-consumer dependency differently (``producer >> consumers`` with +``trigger_rule="always"``). + +.. note:: + The ``dbt_producer_watcher_done`` task is visible in the Airflow UI. Click on it to see a + description of its purpose in the "Doc" tab. + Watcher dbt Execution Queue ........................... @@ -341,7 +385,16 @@ Test behavior By default, the watcher mode runs tests alongside models via the ``dbt build`` command being executed by the producer ``DbtProducerWatcherOperator`` operator. -As a starting point, this execution mode does not support the ``TestBehavior.AFTER_EACH`` behavior, since the tests are not run as individual tasks. Since this is the default ``TestBehavior`` in Cosmos, we are injecting ``EmptyOperator`` as a starting point to ensure a seamless transition to the new mode. +.. versionchanged:: 1.14.0 + +Starting with Cosmos 1.14.0, ``TestBehavior.AFTER_EACH`` is fully supported in ``ExecutionMode.WATCHER``. +Each test task is rendered as a ``DbtTestWatcherOperator`` (a ``DbtConsumerWatcherSensor`` subclass) that watches +the aggregated test results published by the producer via XCom. This means test tasks now behave as real sensors +rather than no-op placeholders. + +In Cosmos versions prior to 1.14.0, ``TestBehavior.AFTER_EACH`` was not supported by the watcher mode because tests +were not run as individual tasks. Since ``TestBehavior.AFTER_EACH`` is the default ``TestBehavior`` in Cosmos, +``EmptyOperator`` tasks were injected as placeholders to ensure a seamless transition to the new mode. The ``TestBehavior.BUILD`` behavior is embedded in the producer ``DbtProducerWatcherOperator`` operator. @@ -370,6 +423,15 @@ This use case is not currently supported by the ``ExecutionMode.WATCHER``, since We have a follow-up ticket to `further investigate this use case `_. +Concurrent DAG runs with ``depends_on_past`` +'''''''''''''''''''''''''''''''''''''''''''' + +When ``depends_on_past=True`` is used together with concurrent DAG runs, a race can occur between consecutive runs where the next run's producer starts while the previous run's consumer fallback is still executing, causing two dbt processes to write to the same models concurrently. + +As a workaround, set ``max_active_runs=1`` on the DAG. + +For details, see `#2596 `_. + Advanced config +++++++++++++++ diff --git a/docs/guides/run_dbt/container/watcher-kubernetes-execution-mode.rst b/docs/guides/run_dbt/container/watcher-kubernetes-execution-mode.rst index 429446ab4d..b43c19ecdd 100644 --- a/docs/guides/run_dbt/container/watcher-kubernetes-execution-mode.rst +++ b/docs/guides/run_dbt/container/watcher-kubernetes-execution-mode.rst @@ -165,7 +165,7 @@ The following limitations from ``ExecutionMode.WATCHER`` also apply to ``Executi * **Individual dbt Operators**: Only ``DbtSeedWatcherKubernetesOperator``, ``DbtSnapshotWatcherKubernetesOperator``, and ``DbtRunWatcherKubernetesOperator`` are implemented. The ``DbtTestWatcherKubernetesOperator`` is currently a placeholder. -* **Test behavior**: The ``TestBehavior.AFTER_EACH`` is not supported. Tests are run as part of the ``dbt build`` command by the producer task. +* **Test behavior**: Unlike ``ExecutionMode.WATCHER`` (which fully supports ``TestBehavior.AFTER_EACH`` since Cosmos 1.14.0), ``ExecutionMode.WATCHER_KUBERNETES`` does not yet support ``TestBehavior.AFTER_EACH``. Tests are run as part of the ``dbt build`` command by the producer task, and test tasks are rendered as ``EmptyOperator`` placeholders. This is tracked in `#1974 `_. * **Source freshness nodes**: The ``dbt build`` command does not run source freshness checks. diff --git a/docs/guides/translate_dbt_to_airflow/selecting-excluding.rst b/docs/guides/translate_dbt_to_airflow/selecting-excluding.rst index fc08065580..c18443cedb 100644 --- a/docs/guides/translate_dbt_to_airflow/selecting-excluding.rst +++ b/docs/guides/translate_dbt_to_airflow/selecting-excluding.rst @@ -221,6 +221,8 @@ Cosmos distinguishes between two types of errors when parsing YAML selectors: - Missing required ``definition`` key These errors indicate malformed YAML structure and will raise a ``CosmosValueError`` immediately when calling ``YamlSelectors.parse()``. + You can modify this behavior to log instead of raising an exception by setting ``AIRFLOW__COSMOS__ENABLE_LAX_SELECTOR_PARSING=1``. + This allows DAGs referencing a manifest with one or more malformed selectors to render as long as the ``selector`` argument is valid. - **Selector Definition Errors** - These are isolated and surfaced when accessing the selector: diff --git a/docs/optimize_performance/memory_optimization.rst b/docs/optimize_performance/memory_optimization.rst index 3ffbaeb058..2e3b07604b 100644 --- a/docs/optimize_performance/memory_optimization.rst +++ b/docs/optimize_performance/memory_optimization.rst @@ -174,7 +174,44 @@ Cosmos provides various configuration options and execution modes to optimize me ------------------------------------------------------------------------------- -6. Enable Task Profiling with Debug Mode +6. Enable Memory-Optimized Imports (Cosmos < 1.14.0 only) ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +**Applies to**: Cosmos versions **earlier than 1.14.0**. Since Cosmos 1.14.0, ``cosmos/__init__.py`` uses lazy imports by default and this setting is a no-op. If you are on 1.14.0 or newer, skip this section. + +**Impact**: High - Reduces memory footprint both at the DAG processor and at worker nodes. + +**Configuration**: + +.. code-block:: cfg + + # In airflow.cfg + [cosmos] + enable_memory_optimised_imports = True + +.. code-block:: bash + + # Or via environment variable + export AIRFLOW__COSMOS__ENABLE_MEMORY_OPTIMISED_IMPORTS=True + +**What it does**: Disables eager imports in ``cosmos/__init__.py``, preventing unused modules and classes from being loaded into memory. + +**Note**: When enabled, you must use full module paths for importing classes, functions and objects from Cosmos: + +.. code-block:: python + + # Instead of: + from cosmos import DbtDag, ProjectConfig, RenderConfig + + # Use: + from cosmos.airflow.dag import DbtDag + from cosmos.config import ProjectConfig, RenderConfig + +**Default**: ``False`` + +------------------------------------------------------------------------------- + +7. Enable Task Profiling with Debug Mode ++++++++++++++++++++++++++++++++++++++++ **Impact**: Low - Provides visibility into memory usage patterns to help identify optimization opportunities and prevent OOM issues. diff --git a/docs/reference/configs/cosmos-conf.rst b/docs/reference/configs/cosmos-conf.rst index e3e3ae0ba6..507bba8d5e 100644 --- a/docs/reference/configs/cosmos-conf.rst +++ b/docs/reference/configs/cosmos-conf.rst @@ -60,6 +60,15 @@ This page lists all available Airflow configurations that affect ``astronomer-co - Default: ``True`` - Environment Variable: ``AIRFLOW__COSMOS__ENABLE_CACHE_DBT_YAML_SELECTORS`` +.. _enable_lax_selector_parsing: + +`enable_lax_selector_parsing`_: + Enable or disable lax parsing of YAML selectors when using ``LoadMode.DBT_MANIFEST`` with ``RenderConfig.selector``. + Lax parsing instructs the parser to log errors for selectors with malformed YAML structure instead of raising an exception. + + - Default: ``False`` + - Environment Variable: ``AIRFLOW__COSMOS__ENABLE_LAX_SELECTOR_PARSING`` + .. _enable_cache_partial_parse: `enable_cache_partial_parse`_: @@ -295,13 +304,35 @@ This page lists all available Airflow configurations that affect ``astronomer-co .. _enable_memory_optimised_imports: `enable_memory_optimised_imports`_: - (Introduced in Cosmos 1.10.1): This setting is a no-op since Cosmos 1.14.0. All imports in ``cosmos/__init__.py`` - are now lazy by default — modules are only loaded on first access. This provides the same memory optimization - automatically, without requiring users to change their import paths. The setting is still accepted but has no effect. + (Introduced in Cosmos 1.10.1): On Cosmos versions earlier than 1.14.0, eager imports in ``cosmos/__init__.py`` + exposed all Cosmos classes at the top level, which could significantly increase memory usage—even when Cosmos was + just installed but not actively used. This option allowed disabling those eager imports to reduce memory footprint. + When enabled, users had to access Cosmos classes via their full module paths, avoiding the overhead of importing + unused modules and classes. - Default: ``False`` - Environment Variable: ``AIRFLOW__COSMOS__ENABLE_MEMORY_OPTIMISED_IMPORTS`` + .. note:: + Since Cosmos 1.14.0, all imports in ``cosmos/__init__.py`` are lazy by default — modules are only loaded on + first access. This provides the same memory optimization automatically, without requiring users to change their + import paths. On 1.14.0 and newer, this setting is accepted but has no effect. The guidance below is relevant + only for users on Cosmos versions earlier than 1.14.0. + + When enabled, import Cosmos classes via their full module paths: + + .. literalinclude:: ../../../dev/dags/basic_cosmos_dag_full_module_path_imports.py + :language: python + :start-after: [START cosmos_explicit_imports] + :end-before: [END cosmos_explicit_imports] + + as opposed to the following approach you might have when this option is disabled (default): + + .. literalinclude:: ../../../dev/dags/basic_cosmos_dag.py + :language: python + :start-after: [START cosmos_init_imports] + :end-before: [END cosmos_init_imports] + .. _enable_telemetry: `enable_telemetry`_: diff --git a/scripts/test/kubernetes-setup.sh b/scripts/test/kubernetes-setup.sh index 267967ee63..fd35ec9141 100644 --- a/scripts/test/kubernetes-setup.sh +++ b/scripts/test/kubernetes-setup.sh @@ -24,8 +24,12 @@ kubectl apply -f scripts/test/postgres-deployment.yaml # Build the Docker image with tag 'dbt-jaffle-shop:1.0.0' using the specified Dockerfile cd dev && docker build --progress=plain --no-cache -t dbt-jaffle-shop:1.0.0 -f Dockerfile.postgres_profile_docker_k8s . -# Load the Docker image into the local KIND cluster +# Build the Docker image for the watcher failing tests project +docker build --progress=plain --no-cache -t dbt-watcher-failing-tests:1.0.0 -f Dockerfile.watcher_failing_tests_k8s . + +# Load the Docker images into the local KIND cluster kind load docker-image dbt-jaffle-shop:1.0.0 +kind load docker-image dbt-watcher-failing-tests:1.0.0 # Retrieve the name of the PostgreSQL pod using the label selector 'app=postgres' # The output is filtered to get the first pod's name diff --git a/tests/conftest.py b/tests/conftest.py index a52b96e0a9..3670e14b87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import json +import logging from unittest.mock import Mock, patch import pytest @@ -8,6 +9,63 @@ from packaging.version import Version from cosmos.constants import AIRFLOW_VERSION +from cosmos.log import CosmosRichLogger + + +@pytest.fixture(autouse=True, scope="session") +def _cache_airflow_in_process_api(): + """Cache the InProcessExecutionAPI to avoid per-task FastAPI app creation in dag.test(). + + Airflow 3.x's dag.test() creates a new InProcessExecutionAPI — a full FastAPI + application with ASGI middleware, JWT auth, dependency injection, and an async + event loop — for every single task. This adds ~6-8s of overhead per task, + making a 13-task DAG take ~80s instead of ~2s. + + This fixture patches in_process_api_server() to return a cached instance, + so the FastAPI app is created once and reused across all tasks and tests. + """ + if AIRFLOW_VERSION < Version("3.1"): + yield + return + + try: + from airflow.sdk.execution_time import supervisor as supervisor_module + + _original_fn = supervisor_module.in_process_api_server + except (ImportError, AttributeError): + yield + return + + _cached_api = None + + def cached_in_process_api_server(): + nonlocal _cached_api + if _cached_api is None: + _cached_api = _original_fn() + return _cached_api + + supervisor_module.in_process_api_server = cached_in_process_api_server + try: + yield + finally: + supervisor_module.in_process_api_server = _original_fn + + +@pytest.fixture(autouse=True) +def _cleanup_rich_loggers(): + """Replace any CosmosRichLogger instances with standard loggers after each test. + + get_logger() registers CosmosRichLogger instances in Python's global logging cache. + If a cosmos module is imported for the first time while rich_logging=True + (via monkeypatch in test_log.py), its module-level logger becomes a CosmosRichLogger + permanently, polluting subsequent tests with the (astronomer-cosmos) prefix. + """ + yield + manager = logging.Logger.manager + for name in list(manager.loggerDict): + obj = manager.loggerDict[name] + if isinstance(obj, CosmosRichLogger): + obj.__class__ = logging.Logger def make_task_instance(task, **kwargs): diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 8e0c31ff63..6c7d8324c7 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -52,7 +52,18 @@ SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none")) # File and directory names to skip when copying dbt project trees in tests (keeps fixture output deterministic). -_DBT_PROJECT_COPY_IGNORED_FILE_AND_DIR_NAMES = (".user.yml", ".DS_Store", "logs", "target") +_DBT_PROJECT_COPY_IGNORED_FILE_AND_DIR_NAMES = ( + # Mirrors dbt .gitignore entries + "target", + "dbt_packages", + "dbt_internal_packages", + "logs", + # Local/OS artifacts + ".user.yml", + ".DS_Store", + "__pycache__", + "venv", +) if AIRFLOW_VERSION.major >= _AIRFLOW3_MAJOR_VERSION: object_storage_path = "airflow.sdk.ObjectStoragePath" @@ -558,7 +569,9 @@ def test_load_via_manifest_with_selectors_and_missing_definitions(): assert err_info.value.args[0] == f"Selectors not found in manifest file `{project_config.manifest_path}`" -def test_load_via_manifest_with_selectors_and_missing_selector(): +@pytest.mark.parametrize("enable_lax_selector_parsing", [True, False]) +def test_load_via_manifest_with_selectors_and_missing_selector(enable_lax_selector_parsing, monkeypatch): + monkeypatch.setattr("cosmos.settings.enable_lax_selector_parsing", enable_lax_selector_parsing) project_config = ProjectConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, manifest_path=SAMPLE_MANIFEST_SELECTORS ) @@ -582,7 +595,13 @@ def test_load_via_manifest_with_selectors_and_missing_selector(): with pytest.raises(CosmosLoadDbtException) as err_info: dbt_graph.load_from_dbt_manifest() - assert "Selector `my_selector` not found in parsed YAML selectors" in err_info.value.args[0] + error_message = err_info.value.args[0] + assert "Selector `my_selector` not found in parsed YAML selectors" in error_message + + if enable_lax_selector_parsing: + assert "Check logs" in error_message + else: + assert "Check logs" not in error_message def test_load_via_manifest_with_selectors(): @@ -2346,9 +2365,9 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir # Different macOS versions have produced different hashes for this directory. The first value below is a # historical macOS-specific hash, while the second matches the Linux hash asserted in the else-branch. We # allow both here so that the test is stable across macOS versions and when macOS hashing matches Linux. - assert hash_dir in ("9d95cbf6529e2ab51fadd6a3f0a3971f", "633a523f295ef0cd496525428815537b") + assert hash_dir in ("9d95cbf6529e2ab51fadd6a3f0a3971f", "5c1aed937708e585054c874ff8f33fd1") else: - assert hash_dir == "633a523f295ef0cd496525428815537b" + assert hash_dir == "5c1aed937708e585054c874ff8f33fd1" @patch("cosmos.dbt.graph.datetime") @@ -2382,15 +2401,47 @@ def test_save_yaml_selectors_cache(mock_variable_set, mock_datetime, tmp_dbt_pro hash_dir, hash_selectors, hash_impl = version.split(",") assert hash_selectors == "43303af03e84e3b51fbfcf598261fae4" - assert hash_impl == "4c93048c66ca45356e1677511447c7ba" + assert hash_impl == "86424c8b70c2e9b6d1f595c7ec9a8291" if sys.platform == "darwin": # Some macOS versions compute a different directory hash than Linux, while others match the Linux behavior. # The first value is the macOS-specific hash; the second value is the Linux hash, which certain macOS versions also produce. # We allow both here to keep the test stable across macOS releases, while non-macOS platforms assert only the Linux hash. - assert hash_dir in ("9d95cbf6529e2ab51fadd6a3f0a3971f", "633a523f295ef0cd496525428815537b") + assert hash_dir in ("9d95cbf6529e2ab51fadd6a3f0a3971f", "5c1aed937708e585054c874ff8f33fd1") else: - assert hash_dir == "633a523f295ef0cd496525428815537b" + assert hash_dir == "5c1aed937708e585054c874ff8f33fd1" + + +@pytest.mark.skipif(AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION, reason="AirflowRuntimeError is Airflow 3+ only") +@patch("cosmos.dbt.graph.Variable.set") +def test_save_dbt_ls_cache_variable_set_failure(mock_variable_set, tmp_dbt_project_dir, caplog): + from airflow.sdk.exceptions import AirflowRuntimeError + + mock_variable_set.side_effect = AirflowRuntimeError(MagicMock()) + graph = DbtGraph(cache_identifier="something", project=ProjectConfig(dbt_project_path=tmp_dbt_project_dir)) + with caplog.at_level(logging.WARNING, logger="cosmos.dbt.graph"): + graph.save_dbt_ls_cache("some output") + assert "Failed to save Cosmos dbt ls cache" in caplog.text + assert "cosmos_cache__something" in caplog.text + assert "AIRFLOW__COSMOS__REMOTE_CACHE_DIR" in caplog.text + + +@pytest.mark.skipif(AIRFLOW_VERSION.major < _AIRFLOW3_MAJOR_VERSION, reason="AirflowRuntimeError is Airflow 3+ only") +@patch("cosmos.dbt.graph.Variable.set") +def test_save_yaml_selectors_cache_variable_set_failure(mock_variable_set, tmp_dbt_project_dir, caplog): + from airflow.sdk.exceptions import AirflowRuntimeError + + mock_variable_set.side_effect = AirflowRuntimeError(MagicMock()) + graph = DbtGraph(cache_identifier="something", project=ProjectConfig(dbt_project_path=tmp_dbt_project_dir)) + selectors = YamlSelectors( + {"staging_orders": {"name": "staging_orders", "definition": {"method": "tag", "value": "tag_a"}}}, + {"select": ["tag:tag_a"], "exclude": None}, + ) + with caplog.at_level(logging.WARNING, logger="cosmos.dbt.graph"): + graph.save_yaml_selectors_cache(selectors) + assert "Failed to save Cosmos YAML selectors cache" in caplog.text + assert "cosmos_cache__something" in caplog.text + assert "AIRFLOW__COSMOS__REMOTE_CACHE_DIR" in caplog.text @pytest.mark.integration diff --git a/tests/dbt/test_runner.py b/tests/dbt/test_runner.py index e312dad4f2..90e0d46fbd 100644 --- a/tests/dbt/test_runner.py +++ b/tests/dbt/test_runner.py @@ -226,6 +226,8 @@ class _MockTI: def __init__(self): self.openlineage_events_completes = [] self.store = {} + self.dag_id = "test_dag" + self.task_id = "test_task" def xcom_push(self, key, value, **_): self.store[key] = value @@ -299,7 +301,11 @@ def invoke(self, *args): ) op2.invocation_mode = InvocationMode.DBT_RUNNER - op2.execute(context=mock_context) + with ( + patch("cosmos.operators._watcher.xcom._persist_backup"), + patch("airflow.models.Variable"), + ): + op2.execute(context=mock_context) # Verify: # 1. We have two dbt Runner instances (cached + new with callbacks) diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index f28e157cfb..3ff6762524 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -1552,7 +1552,7 @@ def test_exposure_selector(): ) def test_valid_method_yaml_selectors(selector_name, selector_definition, parsed_selector): selectors = {selector_name: selector_definition} - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) result = yaml_selectors.get_parsed(selector_name) @@ -1651,7 +1651,7 @@ def test_valid_method_yaml_selectors(selector_name, selector_definition, parsed_ ) def test_valid_graph_operator_yaml_selectors(selector_name, selector_definition, parsed_selector): selectors = {selector_name: selector_definition} - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) result = yaml_selectors.get_parsed(selector_name) @@ -1736,7 +1736,7 @@ def test_valid_graph_operator_yaml_selectors(selector_name, selector_definition, ) def test_invalid_cosmos_method_yaml_selectors(selector_name, selector_definition, exception_msg): selectors = {selector_name: selector_definition} - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) with pytest.raises(CosmosValueError) as err_info: _ = yaml_selectors.get_parsed(selector_name) @@ -1962,7 +1962,7 @@ def test_invalid_cosmos_method_yaml_selectors(selector_name, selector_definition ) def test_invalid_dbt_method_yaml_selectors(selector_name, selector_definition, exception_msg): selectors = {selector_name: selector_definition} - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) with pytest.raises(CosmosValueError) as err_info: _ = yaml_selectors.get_parsed(selector_name) @@ -1992,14 +1992,20 @@ def test_invalid_dbt_method_yaml_selectors(selector_name, selector_definition, e ), ], ) -def test_invalid_yaml_structure_raises_immediately(selector_name, selector_definition, exception_msg): - """Test that structural YAML errors (missing name/definition) raise immediately from parse().""" +@pytest.mark.parametrize("lax_parsing_enabled", [True, False]) +def test_invalid_yaml_structures(selector_name, selector_definition, exception_msg, lax_parsing_enabled, caplog): selectors = {selector_name: selector_definition} - with pytest.raises(CosmosValueError) as err_info: - _ = YamlSelectors.parse(selectors) + if lax_parsing_enabled: + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=lax_parsing_enabled) + result = yaml_selectors.get_parsed(selector_name) - assert exception_msg in str(err_info.value) + assert result is None + assert exception_msg in caplog.text + else: + with pytest.raises(CosmosValueError) as err_info: + _ = YamlSelectors.parse(selectors, lax_parsing_enabled=False) + assert exception_msg in str(err_info.value) @pytest.mark.parametrize( @@ -2012,7 +2018,7 @@ def test_invalid_yaml_structure_raises_immediately(selector_name, selector_defin ) def test_invalid_value_type_for_list_keys(selector_name, definition_key, invalid_value, expected_error_fragment): selectors = {selector_name: {"name": selector_name, "definition": {definition_key: invalid_value}}} - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) with pytest.raises(CosmosValueError) as err_info: _ = yaml_selectors.get_parsed(selector_name) @@ -2039,7 +2045,7 @@ def test_invalid_items_in_selector_lists(selector_name, definition_key, invalid_ } } - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) with pytest.raises(CosmosValueError) as err_info: _ = yaml_selectors.get_parsed(selector_name) @@ -2068,7 +2074,7 @@ def test_non_string_keys_in_selector_dicts(selector_name, definition_key, non_st } } - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) with pytest.raises(CosmosValueError) as err_info: _ = yaml_selectors.get_parsed(selector_name) @@ -2092,7 +2098,7 @@ def test_selector_reference_resolves_from_cache(): }, } - yaml_selectors = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors, lax_parsing_enabled=False) base_result = yaml_selectors.get_parsed("base_selector") reference_result = yaml_selectors.get_parsed("reference_selector") diff --git a/tests/operators/_watcher/test_state.py b/tests/operators/_watcher/test_state.py index e027ae3744..2612853d78 100644 --- a/tests/operators/_watcher/test_state.py +++ b/tests/operators/_watcher/test_state.py @@ -2,7 +2,11 @@ from __future__ import annotations +import base64 +import json import logging +import zlib +from unittest.mock import MagicMock, patch import pytest @@ -12,6 +16,15 @@ is_dbt_node_status_skipped, is_dbt_node_status_success, is_dbt_node_status_terminal, + is_producer_task_terminated, + safe_xcom_push, +) +from cosmos.operators._watcher.xcom import ( + _backup_xcom_to_variable, + _delete_xcom_backup_variable, + _init_xcom_backup, + _persist_backup, + _restore_xcom_from_variable, ) @@ -51,6 +64,18 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None): assert is_dbt_node_status_terminal(status) is False +class TestProducerTaskTerminated: + """Tests for is_producer_task_terminated helper.""" + + @pytest.mark.parametrize("state", ["success", "failed", "skipped", "upstream_failed", "removed"]) + def test_terminal_states(self, state: str): + assert is_producer_task_terminated(state) is True + + @pytest.mark.parametrize("state", ["running", "deferred", "queued", "scheduled", "up_for_reschedule", None, ""]) + def test_non_terminal_states(self, state: str | None): + assert is_producer_task_terminated(state) is False + + @pytest.mark.parametrize( "dbt_event,expect_error,status", [ @@ -87,3 +112,155 @@ def test_log_dbt_event(caplog, dbt_event, expect_error, status): if status: assert any(status in msg for msg in messages) + + +class _MockTI: + def __init__(self): + self.store = {} + self.dag_id = "test_dag" + self.task_id = "test_task" + + def xcom_push(self, key, value, **_): + self.store[key] = value + + +class TestInitXcomBackup: + def test_sets_var_key_and_buffer_on_ti(self): + ti = _MockTI() + context = {"ti": ti, "run_id": "manual__2026-01-01"} + + _init_xcom_backup(context) + + assert isinstance(ti._cosmos_xcom_backup_var_key, str) + assert "test_dag" in ti._cosmos_xcom_backup_var_key + assert ti._cosmos_xcom_backup_buffer == {} + + def test_includes_task_group_id_when_present(self): + ti = _MockTI() + ti.task = MagicMock(task_group_id="my_group") + context = {"ti": ti, "run_id": "manual__2026-01-01"} + + _init_xcom_backup(context) + + assert "my_group" in ti._cosmos_xcom_backup_var_key + + +class TestPersistBackup: + @patch("airflow.models.Variable") + def test_writes_compressed_data_to_variable(self, mock_variable): + _persist_backup("my_var_key", {"key1": "value1", "key2": "value2"}) + + mock_variable.set.assert_called_once() + var_key, compressed = mock_variable.set.call_args[0] + assert var_key == "my_var_key" + data = json.loads(zlib.decompress(base64.b64decode(compressed.encode("utf-8"))).decode("utf-8")) + assert data == {"key1": "value1", "key2": "value2"} + + @patch("airflow.models.Variable") + def test_skips_empty_buffer(self, mock_variable): + _persist_backup("my_var_key", {}) + + mock_variable.set.assert_not_called() + + +class TestSafeXcomPushBackup: + @patch("cosmos.operators._watcher.xcom._persist_backup") + def test_accumulates_in_backup_buffer_when_active(self, mock_persist): + ti = _MockTI() + context = {"ti": ti, "run_id": "test_run"} + _init_xcom_backup(context) + + safe_xcom_push(ti, "status_key", "success") + + assert ti.store["status_key"] == "success" + assert ti._cosmos_xcom_backup_buffer["status_key"] == "success" + mock_persist.assert_called_once() + + @patch("cosmos.operators._watcher.xcom._persist_backup") + def test_does_not_backup_without_init(self, mock_persist): + ti = _MockTI() + + safe_xcom_push(ti, "key", "value") + + assert ti.store["key"] == "value" + mock_persist.assert_not_called() + + +class TestBackupXcomToVariable: + @patch("cosmos.operators._watcher.xcom._persist_backup") + def test_flushes_buffer(self, mock_persist): + ti = _MockTI() + context = {"ti": ti, "run_id": "test_run"} + _init_xcom_backup(context) + ti._cosmos_xcom_backup_buffer = {"k": "v"} + + _backup_xcom_to_variable(context) + + mock_persist.assert_called_once_with(ti._cosmos_xcom_backup_var_key, {"k": "v"}) + + @patch("cosmos.operators._watcher.xcom._persist_backup") + def test_noop_without_init(self, mock_persist): + ti = _MockTI() + context = {"ti": ti} + + _backup_xcom_to_variable(context) + + mock_persist.assert_not_called() + + +class TestDeleteXcomBackupVariable: + @patch("airflow.models.Variable") + def test_deletes_variable(self, mock_variable): + ti = _MockTI() + context = {"ti": ti, "run_id": "test_run"} + _init_xcom_backup(context) + + _delete_xcom_backup_variable(context) + + mock_variable.delete.assert_called_once_with(ti._cosmos_xcom_backup_var_key) + + @patch("airflow.models.Variable") + def test_noop_without_init(self, mock_variable): + ti = _MockTI() + context = {"ti": ti} + + _delete_xcom_backup_variable(context) + + mock_variable.delete.assert_not_called() + + @patch("airflow.models.Variable") + def test_ignores_key_error(self, mock_variable): + ti = _MockTI() + context = {"ti": ti, "run_id": "test_run"} + _init_xcom_backup(context) + mock_variable.delete.side_effect = KeyError("not found") + + _delete_xcom_backup_variable(context) # should not raise + + +class TestRestoreXcomFromVariable: + @patch("cosmos.operators._watcher.xcom._persist_backup") + @patch("airflow.models.Variable") + def test_restores_entries(self, mock_variable, mock_persist): + ti = _MockTI() + backup = {"key1": "val1", "key2": "val2"} + compressed = base64.b64encode(zlib.compress(json.dumps(backup).encode("utf-8"))).decode("utf-8") + mock_variable.get.return_value = compressed + context = {"ti": ti, "run_id": "test_run"} + + result = _restore_xcom_from_variable(context) + + assert result is True + assert ti.store["key1"] == "val1" + assert ti.store["key2"] == "val2" + mock_variable.delete.assert_called_once() + + @patch("airflow.models.Variable") + def test_returns_false_when_no_backup(self, mock_variable): + ti = _MockTI() + mock_variable.get.return_value = None + context = {"ti": ti, "run_id": "test_run"} + + result = _restore_xcom_from_variable(context) + + assert result is False diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index 05092e776a..702ee2b119 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -396,3 +396,81 @@ async def test_log_startup_events_waits_and_sleeps(self, mock_terminal, mock_sle # ensure terminal check was called mock_terminal.assert_any_call("running") mock_terminal.assert_any_call("error") + + +class TestBuildSuccessEvent: + def setup_method(self): + self.trigger = WatcherTrigger( + model_unique_id="model.test", + producer_task_id="task_1", + dag_id="dag_1", + run_id="run_123", + map_index=None, + ) + + def test_basic_success(self): + event = self.trigger._build_success_event(compiled_sql=None) + assert event == {"status": "success"} + + def test_with_compiled_sql(self): + event = self.trigger._build_success_event(compiled_sql="SELECT 1") + assert event["compiled_sql"] == "SELECT 1" + + def test_with_outlet_uris(self): + self.trigger._outlet_uris = ["uri://dataset1"] + event = self.trigger._build_success_event(compiled_sql=None) + assert event["outlet_uris"] == ["uri://dataset1"] + + def test_with_compiled_sql_and_outlet_uris(self): + self.trigger._outlet_uris = ["uri://dataset1"] + event = self.trigger._build_success_event(compiled_sql="SELECT 1") + assert event == {"status": "success", "compiled_sql": "SELECT 1", "outlet_uris": ["uri://dataset1"]} + + +@pytest.mark.asyncio +class TestWatcherTriggerProducerSkipped: + + def setup_method(self): + self.trigger = WatcherTrigger( + model_unique_id="model.test", + producer_task_id="task_1", + dag_id="dag_1", + run_id="run_123", + map_index=None, + poke_interval=0.001, + ) + + @patch("cosmos.operators._watcher.triggerer._log_dbt_event") + @patch("cosmos.operators._watcher.triggerer.WatcherTrigger._log_startup_events") + async def test_run_producer_skipped_yields_producer_skipped_event( + self, mock_startup_events, mock_dbt_event, caplog + ): + """Test that when producer is skipped, trigger yields failed with PRODUCER_SKIPPED reason.""" + get_xcom_val_mock = AsyncMock(return_value=None) + get_producer_status_mock = AsyncMock(return_value="skipped") + parse_mock = AsyncMock(return_value=(None, None)) + + caplog.set_level("INFO") + + with ( + patch.object(self.trigger, "get_xcom_val", get_xcom_val_mock), + patch.object(self.trigger, "_get_producer_task_status", get_producer_status_mock), + patch.object(self.trigger, "_parse_dbt_node_status_and_compiled_sql", parse_mock), + ): + events = [] + async for event in self.trigger.run(): + events.append(event) + + assert len(events) == 1 + assert events[0].payload == {"status": "failed", "reason": WatcherEventReason.PRODUCER_SKIPPED} + assert "was skipped" in caplog.text + + @patch("cosmos.operators._watcher.triggerer.WatcherTrigger._get_producer_task_status", new_callable=AsyncMock) + async def test_log_startup_events_returns_on_skipped_producer(self, mock_producer_status): + """Test that _log_startup_events returns early when producer is skipped.""" + mock_producer_status.return_value = "skipped" + self.trigger.get_xcom_val = AsyncMock(return_value=None) + + await self.trigger._log_startup_events() + + mock_producer_status.assert_called_once() diff --git a/tests/operators/_watcher/test_watcher_base.py b/tests/operators/_watcher/test_watcher_base.py index eac6a70e2a..912c3f5ad5 100644 --- a/tests/operators/_watcher/test_watcher_base.py +++ b/tests/operators/_watcher/test_watcher_base.py @@ -1,7 +1,7 @@ -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest -from airflow.exceptions import AirflowSkipException +from airflow.exceptions import AirflowException, AirflowSkipException from cosmos.operators._watcher.base import BaseConsumerSensor, _process_dbt_log_event from cosmos.operators.local import DbtRunLocalOperator @@ -132,3 +132,59 @@ class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): ): with pytest.raises(AirflowSkipException, match="was skipped by the dbt command"): sensor.poke(context) + + +class TestHandleNoDbtNodeStatus: + """Tests for BaseConsumerSensor._handle_no_dbt_node_status.""" + + def _make_sensor(self): + class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): + something_to_be_implemented = True + + extra_context = {"dbt_node_config": {"unique_id": "model.jaffle_shop.stg_orders"}} + sensor = SubclassBaseConsumerSensor( + task_id="test_sensor", + producer_task_id="dbt_run_local", + profile_config=None, + project_dir="/tmp/sample_project", + extra_context=extra_context, + ) + sensor._get_producer_task_status = MagicMock(return_value=None) + return sensor + + @patch("cosmos.operators._watcher.base.BaseConsumerSensor._fallback_to_non_watcher_run", return_value=True) + def test_producer_failed_with_no_poke_retries_falls_back(self, mock_fallback): + sensor = self._make_sensor() + sensor.poke_retry_number = 0 + context = MagicMock() + + result = sensor._handle_no_dbt_node_status("failed", try_number=1, context=context) + + assert result is True + mock_fallback.assert_called_once_with(1, context) + + def test_producer_failed_with_poke_retries_raises(self): + sensor = self._make_sensor() + sensor.poke_retry_number = 1 + + with pytest.raises(AirflowException, match="dbt build command failed"): + sensor._handle_no_dbt_node_status("failed", try_number=1, context=MagicMock()) + + @patch("cosmos.operators._watcher.base.BaseConsumerSensor._fallback_to_non_watcher_run", return_value=True) + def test_producer_skipped_falls_back(self, mock_fallback): + sensor = self._make_sensor() + context = MagicMock() + + result = sensor._handle_no_dbt_node_status("skipped", try_number=1, context=context) + + assert result is True + mock_fallback.assert_called_once_with(1, context) + + def test_no_status_increments_poke_retry(self): + sensor = self._make_sensor() + sensor.poke_retry_number = 0 + + result = sensor._handle_no_dbt_node_status(None, try_number=1, context=MagicMock()) + + assert result is False + assert sensor.poke_retry_number == 1 diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 0fa59adee4..f1a8d31ddb 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -684,7 +684,7 @@ def test_run_operator_dataset_with_airflow_3_and_enabled_dataset_alias_false_fai caplog.set_level(logging.ERROR) caplog.clear() - run_test_dag(dag, expect_success=False) + run_test_dag(dag, expected_dag_state=DagRunState.FAILED) assert "ERROR" in caplog.text assert "To emit datasets with Airflow 3, the setting `enable_dataset_alias` must be True (default)." in caplog.text @@ -1234,12 +1234,71 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo should_store_compiled_sql=False, ) - dbt_base_operator.calculate_openlineage_events_completes(env={}, project_dir=DBT_PROJ_DIR) + dbt_base_operator.calculate_openlineage_events_completes( + env={}, project_dir=DBT_PROJ_DIR, dbt_command_line=dbt_base_operator.base_cmd + dbt_base_operator.dbt_cmd_flags + ) assert instance.parse.called assert "Unable to parse OpenLineage events" in caplog.text +@patch("cosmos.operators.local.DbtLocalBaseOperator._handle_post_execution") +@patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception") +@patch("cosmos.operators.local.DbtLocalArtifactProcessor") +@patch("cosmos.operators.local.is_openlineage_common_available", True) +@patch("cosmos.config.ProfileConfig.ensure_profile") +@patch("cosmos.operators.local.DbtLocalBaseOperator.invoke_dbt") +@patch("cosmos.operators.local.DbtLocalBaseOperator._clone_project") +@patch("cosmos.operators.local.tempfile.TemporaryDirectory") +def test_run_command_passes_full_cmd_with_profiles_dir_to_openlineage_processor( + mock_tmp_dir, + mock_clone_project, + mock_invoke_dbt, + mock_ensure_profile, + mock_processor, + mock_handle_exception, + mock_handle_post_execution, + tmp_path, +): + """Tests that run_command forwards the full dbt CLI (including --profiles-dir) to DbtLocalArtifactProcessor.""" + profile_path = tmp_path / "profiles" / "profiles.yml" + mock_tmp_dir.return_value.__enter__.return_value = str(tmp_path) + mock_ensure_profile.return_value.__enter__.return_value = (profile_path, {}) + mock_invoke_dbt.return_value = MagicMock() + + mock_processor_instance = MagicMock() + mock_processor_instance.parse.return_value = MagicMock(completes=[]) + mock_processor.return_value = mock_processor_instance + + # Simulate DbtLocalArtifactProcessor.__init__ accepting dbt_command_line + mock_sig = MagicMock() + mock_sig.parameters = {"dbt_command_line": MagicMock()} + + operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + emit_datasets=False, + install_deps=False, + invocation_mode=InvocationMode.SUBPROCESS, + ) + + with patch("cosmos.operators.local.inspect.signature", return_value=mock_sig): + operator.run_command( + cmd=["dbt", "cmd"], + env={}, + context={"run_id": "test_run_id", "task_instance": MagicMock()}, + ) + + mock_processor.assert_called_once() + call_kwargs = mock_processor.call_args.kwargs + assert "dbt_command_line" in call_kwargs + assert "--profiles-dir" in call_kwargs["dbt_command_line"] + # Verify the base command is also present (full_cmd = cmd + flags) + assert "dbt" in call_kwargs["dbt_command_line"] + assert "cmd" in call_kwargs["dbt_command_line"] + + @pytest.mark.parametrize( "operator_class,expected_template", [ diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 62f78cb647..752aba4c8b 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -7,7 +7,7 @@ from unittest.mock import ANY, MagicMock, Mock, patch import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.utils.state import DagRunState try: @@ -38,10 +38,15 @@ DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml" MULTI_FOLDER_DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/multi_folder" +DBT_WATCHER_FAILING_TESTS_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_failing_tests" +DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH = ( + Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_downstream_not_skipped" +) DBT_EXECUTABLE_PATH = Path(__file__).parent.parent.parent / "venv-subprocess/bin/dbt" DBT_PROJECT_WITH_EMPTY_MODEL_PATH = Path(__file__).parent.parent / "sample/dbt_project_with_empty_model" + project_config = ProjectConfig( dbt_project_path=DBT_PROJECT_PATH, ) @@ -61,6 +66,8 @@ class _MockTI: def __init__(self) -> None: self.store: dict[str, str] = {} self.try_number = 1 + self.dag_id = "test_dag" + self.task_id = "test_task" def xcom_push(self, key: str, value: str, **_): self.store[key] = value @@ -146,50 +153,81 @@ def test_dbt_producer_log_format_always_json(): assert op.log_format == "json" -def test_dbt_producer_watcher_operator_pushes_completion_status(): - """Test that operator pushes 'completed' status to XCom in both success and failure cases.""" +@patch("airflow.models.Variable") +@patch("cosmos.operators._watcher.xcom._persist_backup") +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_pushes_completion_status_on_success(mock_execute, mock_persist, mock_variable): + """Test that operator pushes 'completed' status to XCom on success.""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) mock_ti = _MockTI() - context = {"ti": mock_ti} + context = {"ti": mock_ti, "run_id": "test_run"} + + op.execute(context=context) + + assert mock_ti.store.get("task_status") == "completed" + mock_execute.assert_called_once() - # Test success case - with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: + +@patch("airflow.models.Variable") +@patch("cosmos.operators._watcher.xcom._persist_backup") +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_pushes_completion_status_on_failure(mock_execute, mock_persist, mock_variable): + """Test that operator pushes 'completed' status to XCom even when execution fails.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + mock_ti = _MockTI() + context = {"ti": mock_ti, "run_id": "test_run"} + + mock_execute.side_effect = RuntimeError("dbt build failed") + + with pytest.raises(RuntimeError): op.execute(context=context) - # Verify status was pushed - assert mock_ti.store.get("task_status") == "completed" - # Verify parent execute was called - mock_execute.assert_called_once() + assert mock_ti.store.get("task_status") == "completed" + mock_execute.assert_called_once() - # Reset mock and store - mock_ti.store.clear() - # Test failure case - class TestException(Exception): - pass +@patch("cosmos.operators._watcher.xcom._persist_backup") +@patch("cosmos.operators.watcher._delete_xcom_backup_variable") +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_deletes_backup_on_success(mock_execute, mock_delete, mock_persist): + """Test that the XCom backup Variable is deleted after a successful execution.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + mock_ti = _MockTI() + context = {"ti": mock_ti, "run_id": "test_run"} - with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: - mock_execute.side_effect = TestException("test error") + op.execute(context=context) - with pytest.raises(TestException): - op.execute(context=context) + mock_delete.assert_called_once_with(context) - # Verify completed status was pushed even in failure case - assert mock_ti.store.get("task_status") == "completed" - # Verify parent execute was called - mock_execute.assert_called_once() +@patch("cosmos.operators._watcher.xcom._persist_backup") +@patch("cosmos.operators.watcher._delete_xcom_backup_variable") +@patch("cosmos.operators.watcher._backup_xcom_to_variable") +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_keeps_backup_on_failure(mock_execute, mock_backup, mock_delete, mock_persist): + """Test that the XCom backup Variable is persisted (not deleted) when execution fails.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + mock_ti = _MockTI() + context = {"ti": mock_ti, "run_id": "test_run"} -def test_dbt_producer_watcher_operator_requires_task_instance(): + mock_execute.side_effect = RuntimeError("dbt build failed") + + with pytest.raises(RuntimeError): + op.execute(context=context) + + mock_backup.assert_called_once_with(context) + mock_delete.assert_not_called() + + +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_requires_task_instance(mock_execute): op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) context: dict[str, object] = {} - with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: - with pytest.raises(AirflowException) as excinfo: - op.execute(context=context) + with pytest.raises(AirflowException, match="expects a task instance"): + op.execute(context=context) mock_execute.assert_not_called() - assert "expects a task instance" in str(excinfo.value) def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message(caplog): @@ -217,20 +255,19 @@ def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message assert any("ephemeral model or if the model sql file is empty" in message for message in caplog.messages) -def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): +@patch("cosmos.operators.watcher._restore_xcom_from_variable") +@patch("cosmos.operators.local.DbtLocalBaseOperator.execute") +def test_dbt_producer_watcher_operator_skips_retry_attempt(mock_execute, mock_restore): op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) ti = _MockTI() ti.try_number = 2 - context = {"ti": ti} + context = {"ti": ti, "run_id": "test_run"} - with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: - with caplog.at_level(logging.INFO): - result = op.execute(context=context) + with pytest.raises(AirflowSkipException, match="does not support Airflow retries"): + op.execute(context=context) + mock_restore.assert_called_once_with(context) mock_execute.assert_not_called() - assert result is None - assert any("does not support Airflow retries" in message for message in caplog.messages) - assert any("skipping execution" in message for message in caplog.messages) @pytest.mark.parametrize( @@ -238,6 +275,7 @@ def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): [ ({"status": "success"}, None), ({"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}, None), + ({"status": "failed", "reason": WatcherEventReason.PRODUCER_SKIPPED}, None), ( {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}, "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher_operator' logs for details.", @@ -248,7 +286,8 @@ def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): ), ], ) -def test_dbt_consumer_watcher_sensor_execute_complete(event, expected_message): +@patch("cosmos.operators._watcher.base.BaseConsumerSensor._fallback_to_non_watcher_run", return_value=True) +def test_dbt_consumer_watcher_sensor_execute_complete(mock_fallback, event, expected_message): sensor = DbtConsumerWatcherSensor( project_dir=".", profiles_dir=".", @@ -1032,8 +1071,10 @@ def test_poke_failure(self): sensor.poke(context) @patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd") - def test_task_retry(self, mock_build_and_run_cmd): + def test_task_retry_fallback_when_producer_terminated(self, mock_build_and_run_cmd): + """On retry, if the producer has already finished, fall back to running the model locally.""" sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 ti.xcom_pull.return_value = None @@ -1042,6 +1083,22 @@ def test_task_retry(self, mock_build_and_run_cmd): sensor.poke(context) mock_build_and_run_cmd.assert_called_once() + @patch("cosmos.operators._watcher.base.get_xcom_val") + def test_task_retry_keeps_polling_when_producer_still_running(self, mock_get_xcom_val): + """On retry, if the producer is still running, keep polling instead of launching a duplicate dbt run.""" + sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "running" + ti = MagicMock() + ti.try_number = 2 + # _log_startup_events=None, _get_node_status=None, compiled_sql (skipped), _dbt_event=None + ti.xcom_pull.return_value = None + mock_get_xcom_val.return_value = None + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is False + assert sensor.poke_retry_number == 1 + def test_fallback_to_non_watcher_run(self): sensor = self.make_sensor() ti = MagicMock() @@ -1681,12 +1738,13 @@ def test_dbt_task_group_with_watcher(): # assert outcome.state == DagRunState.SUCCESS # Fortunately, when we trigger the DAG run manually, the weight is being respected and the producer task is being picked up in advance. - assert len(dag_dbt_task_group_watcher.task_dict) == 10 + assert len(dag_dbt_task_group_watcher.task_dict) == 11 tasks_names = [task.task_id for task in dag_dbt_task_group_watcher.topological_sort()] expected_task_names = [ "pre_dbt", "dbt_task_group.dbt_producer_watcher", + "dbt_task_group.dbt_producer_watcher_done", "dbt_task_group.raw_customers_seed", "dbt_task_group.raw_orders_seed", "dbt_task_group.raw_payments_seed", @@ -1710,7 +1768,18 @@ def test_dbt_task_group_with_watcher(): assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.customers_run"], DbtRunWatcherOperator) assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.orders_run"], DbtRunWatcherOperator) - assert dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher"].downstream_task_ids == set() + assert dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher"].downstream_task_ids == { + "dbt_task_group.dbt_producer_watcher_done", + } + + done_task = dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher_done"] + assert done_task.__class__.__name__ == "EmptyOperator" + try: + from airflow.task.trigger_rule import TriggerRule + except ImportError: + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef] + + assert done_task.trigger_rule == TriggerRule.NONE_FAILED @pytest.mark.integration @@ -1871,6 +1940,135 @@ def test_sensor_and_producer_different_param_values(mock_bigquery_conn): assert task.execution_timeout == timedelta(seconds=1) +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("2.10"), + reason="dag.test() in Airflow 2.9 hangs when a task fails with retries configured", +) +@pytest.mark.integration +def test_dbt_dag_with_watcher_and_failing_model(caplog): + """ + Run a DbtDag using `ExecutionMode.WATCHER` with a project where one model fails. + model_a succeeds, model_f fails (references nonexistent column). + The producer task fails, retries are skipped, and the DAG completes without hanging. + Reproduces https://github.com/astronomer/astronomer-cosmos/issues/2430 + """ + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + + watcher_dag = DbtDag( + project_config=ProjectConfig(dbt_project_path=DBT_WATCHER_FAILING_TESTS_PATH), + profile_config=profile_config, + start_date=datetime(2023, 1, 1), + dag_id="watcher_failing_tests", + execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), + render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), + default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, + operator_args={"trigger_rule": "none_failed", "execution_timeout": timedelta(seconds=120)}, + dagrun_timeout=timedelta(seconds=120), + ) + outcome = new_test_dag(watcher_dag, expected_dag_state=DagRunState.FAILED) + + assert len(watcher_dag.dbt_graph.filtered_nodes) == 2 + assert len(watcher_dag.task_dict) == 3 + tasks_names = [task.task_id for task in watcher_dag.topological_sort()] + assert tasks_names == ["dbt_producer_watcher", "model_a_run", "model_f_run"] + + assert isinstance(watcher_dag.task_dict["dbt_producer_watcher"], DbtProducerWatcherOperator) + assert isinstance(watcher_dag.task_dict["model_a_run"], DbtRunWatcherOperator) + assert isinstance(watcher_dag.task_dict["model_f_run"], DbtRunWatcherOperator) + + # The DAG should complete (not hang) even though model_f fails + assert outcome.state == DagRunState.FAILED + tis = {ti.task_id: ti for ti in outcome.get_task_instances()} + assert tis["dbt_producer_watcher"].try_number == 2 + + assert tis["dbt_producer_watcher"].state == "skipped" + assert tis["model_a_run"].state == "success" + assert tis["model_f_run"].state == "failed" + + dbt_error_message = """Database Error in model model_f (models/model_f.sql)\n column "this_column_does_not_exist_at_all" does not exist\n LINE 1""" + assert dbt_error_message in caplog.text + + +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("2.10") or (Version("3.0.0") <= AIRFLOW_VERSION < Version("3.2.0")), + reason=( + "dag.test() in Airflow 2.9 hangs when a task fails with retries configured. " + "Airflow 3.0 runs tasks inline via _run_raw_task without the task SDK supervisor, " + "so RuntimeTaskInstance.get_task_states raises NameError for SUPERVISOR_COMMS and " + "the watcher sensor cannot detect producer termination on retry. " + "Airflow 3.1.x crashes during task finalization (SetRenderedFields) when retrying " + "tasks inside a DbtTaskGroup via dag.test()." + ), +) +@pytest.mark.integration +def test_dbt_task_group_watcher_gateway_prevents_downstream_skip(caplog): + """ + Verify that the dbt_producer_watcher_done gateway task prevents the producer's skip + from propagating to tasks downstream of the DbtTaskGroup. + """ + import psycopg2 + from airflow import DAG + from airflow.hooks.base import BaseHook + + from cosmos import DbtTaskGroup + + try: + from airflow.providers.standard.operators.empty import EmptyOperator + except ImportError: + from airflow.operators.empty import EmptyOperator + + # Reset the fail_once sequence using credentials from the same Airflow connection + airflow_conn = BaseHook.get_connection("example_conn") + with psycopg2.connect( + host=airflow_conn.host, + port=airflow_conn.port or 5432, + dbname=airflow_conn.schema or "postgres", + user=airflow_conn.login, + password=airflow_conn.password, + ) as conn: + conn.autocommit = True + with conn.cursor() as cur: + cur.execute("DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq") + + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + + with DAG( + dag_id="watcher_taskgroup_gateway_test", + start_date=datetime(2023, 1, 1), + default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, + dagrun_timeout=timedelta(seconds=120), + ) as dag: + dbt_group = DbtTaskGroup( + group_id="watcher_downstream_not_skipped", + execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), + project_config=ProjectConfig(dbt_project_path=DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH), + profile_config=profile_config, + render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), + operator_args={"trigger_rule": "none_failed", "execution_timeout": timedelta(seconds=120)}, + ) + + # Force gateway >> root consumers so dag.test() respects execution order. + # Using the gateway (not the producer) as upstream ensures consumers see a + # successful upstream even when the producer is skipped. + # This workaround is only needed for dag.test() — the real scheduler uses priority_weight. + # See https://github.com/apache/airflow/issues/56723 + gateway = dag.task_dict["watcher_downstream_not_skipped.dbt_producer_watcher_done"] + for root_task in dbt_group.get_roots(): + if root_task.task_id.endswith("_done") or root_task.task_id.endswith("dbt_producer_watcher"): + continue + gateway >> root_task + + post_dbt = EmptyOperator(task_id="post_dbt") + dbt_group >> post_dbt + + outcome = new_test_dag(dag, expected_dag_state=DagRunState.SUCCESS) + + tis = {ti.task_id: ti for ti in outcome.get_task_instances()} + # The gateway task absorbs the producer skip — post_dbt runs successfully + assert tis["post_dbt"].state == "success" + assert tis["watcher_downstream_not_skipped.dbt_producer_watcher_done"].state == "success" + + def test_dbt_source_watcher_operator_template_fields(): """Test that DbtSourceWatcherOperator includes model_unique_id as a consumer sensor.""" from cosmos.operators._watcher.base import BaseConsumerSensor @@ -1964,8 +2162,9 @@ def test_poke_reads_correct_xcom_key(self): assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used def test_fallback_raises_on_retry(self): - """On retry (try_number > 1), the test sensor should raise since test re-execution is not yet supported.""" + """On retry (try_number > 1) with a terminated producer, the test sensor should raise since test re-execution is not yet supported.""" sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 context = self.make_context(ti) @@ -1973,6 +2172,18 @@ def test_fallback_raises_on_retry(self): with pytest.raises(AirflowException, match="Test re-execution is not yet supported"): sensor.poke(context) + def test_retry_keeps_polling_when_producer_still_running(self): + """On retry, if the producer is still running, the test sensor should keep polling instead of raising.""" + sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "running" + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is False + class TestDefaultFreshnessCallback: """Tests for the _default_freshness_callback function.""" @@ -2073,6 +2284,188 @@ def test_excludes_test_nodes(self): assert node_ids == ["model.pkg.m1"] assert status == "skip" + def test_node_with_clean_upstream_not_skipped(self): + """A node that depends on both a stale source and a clean model should not be skipped. + + Graph: stale_src → A ← clean_model + ↓ + C + + A has a clean path via clean_model so A (and therefore C) should run. + """ + from cosmos.constants import DbtResourceType + from cosmos.dbt.graph import DbtNode + + nodes = { + "model.pkg.clean_model": DbtNode( + unique_id="model.pkg.clean_model", + resource_type=DbtResourceType.MODEL, + depends_on=[], + path_base=Path("/tmp"), + original_file_path=Path("models/clean.sql"), + ), + "model.pkg.A": DbtNode( + unique_id="model.pkg.A", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src", "model.pkg.clean_model"], + path_base=Path("/tmp"), + original_file_path=Path("models/a.sql"), + ), + "model.pkg.C": DbtNode( + unique_id="model.pkg.C", + resource_type=DbtResourceType.MODEL, + depends_on=["model.pkg.A"], + path_base=Path("/tmp"), + original_file_path=Path("models/c.sql"), + ), + } + sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "warn"}]} + node_ids, status = _default_freshness_callback( + context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json + ) + # A has a clean path via clean_model → neither A nor C should be skipped + assert node_ids == [] + assert status == "skip" + + def test_node_skipped_only_when_all_upstreams_stale(self): + """A node whose every upstream is stale or already skipped must be skipped. + + Graph: stale_src1 → A + stale_src2 → B + A, B → C (both parents stale → C must be skipped) + A → D (only A stale, but A has no clean path → D skipped) + """ + from cosmos.constants import DbtResourceType + from cosmos.dbt.graph import DbtNode + + nodes = { + "model.pkg.A": DbtNode( + unique_id="model.pkg.A", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src1"], + path_base=Path("/tmp"), + original_file_path=Path("models/a.sql"), + ), + "model.pkg.B": DbtNode( + unique_id="model.pkg.B", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src2"], + path_base=Path("/tmp"), + original_file_path=Path("models/b.sql"), + ), + "model.pkg.C": DbtNode( + unique_id="model.pkg.C", + resource_type=DbtResourceType.MODEL, + depends_on=["model.pkg.A", "model.pkg.B"], + path_base=Path("/tmp"), + original_file_path=Path("models/c.sql"), + ), + "model.pkg.D": DbtNode( + unique_id="model.pkg.D", + resource_type=DbtResourceType.MODEL, + depends_on=["model.pkg.A"], + path_base=Path("/tmp"), + original_file_path=Path("models/d.sql"), + ), + } + sources_json = { + "results": [ + {"unique_id": "source.pkg.stale_src1", "status": "error"}, + {"unique_id": "source.pkg.stale_src2", "status": "error"}, + ] + } + node_ids, status = _default_freshness_callback( + context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json + ) + assert set(node_ids) == {"model.pkg.A", "model.pkg.B", "model.pkg.C", "model.pkg.D"} + assert status == "skip" + + def test_already_visited_dependent_not_processed_twice(self): + """A dependent reachable via two stale paths is only processed once. + + Graph: stale_src → A + stale_src → B + A, B → C + + A and B are both direct dependents of stale_src. C depends on both A and B. + When A is processed, C is added to visited. When B is then processed, C is + already in visited → the ``if dependent_id in visited: continue`` branch fires. + All three (A, B, C) must still appear in the skip set. + """ + from cosmos.constants import DbtResourceType + from cosmos.dbt.graph import DbtNode + + nodes = { + "model.pkg.A": DbtNode( + unique_id="model.pkg.A", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src"], + path_base=Path("/tmp"), + original_file_path=Path("models/a.sql"), + ), + "model.pkg.B": DbtNode( + unique_id="model.pkg.B", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src"], + path_base=Path("/tmp"), + original_file_path=Path("models/b.sql"), + ), + "model.pkg.C": DbtNode( + unique_id="model.pkg.C", + resource_type=DbtResourceType.MODEL, + depends_on=["model.pkg.A", "model.pkg.B"], + path_base=Path("/tmp"), + original_file_path=Path("models/c.sql"), + ), + } + sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "error"}]} + node_ids, status = _default_freshness_callback( + context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json + ) + assert set(node_ids) == {"model.pkg.A", "model.pkg.B", "model.pkg.C"} + assert status == "skip" + + def test_dependent_node_missing_from_nodes_is_skipped(self): + """A dependent_id whose node cannot be resolved via ``nodes.get`` is silently ignored. + + This covers the ``if dependent_node is None: continue`` guard. In normal operation the + dependents reverse-map is built from ``nodes.items()`` so every id is present; this test + simulates a lookup returning ``None`` (e.g. a corrupt or trimmed nodes dict) by using a + dict subclass that overrides ``get`` to return ``None`` for the nominated key. + """ + from cosmos.constants import DbtResourceType + from cosmos.dbt.graph import DbtNode + + class _NullOnGet(dict): # type: ignore[type-arg] + """dict that returns None for keys listed in ``_null_keys``.""" + + def __init__(self, null_keys: set, *args, **kwargs): # type: ignore[type-arg] + super().__init__(*args, **kwargs) + self._null_keys = null_keys + + def get(self, key, default=None): # type: ignore[override] + if key in self._null_keys: + return None + return super().get(key, default) + + raw_nodes = { + "model.pkg.A": DbtNode( + unique_id="model.pkg.A", + resource_type=DbtResourceType.MODEL, + depends_on=["source.pkg.stale_src"], + path_base=Path("/tmp"), + original_file_path=Path("models/a.sql"), + ), + } + # nodes.get("model.pkg.A") will return None → the node is silently skipped + nodes = _NullOnGet({"model.pkg.A"}, raw_nodes) + sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "error"}]} + node_ids, status = _default_freshness_callback( + context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json + ) + assert node_ids == [] + assert status == "skip" + class TestProducerSourceFreshness: """Tests for source freshness methods on DbtProducerWatcherOperator.""" @@ -2133,3 +2526,27 @@ def test_skipped_node_token_noop_when_empty(self): producer._skipped_node_token(context, []) ti.xcom_push.assert_not_called() assert producer.exclude is None + + def test_run_dbt_runner_skips_callback_during_source_freshness(self): + """run_dbt_runner must not register the XCom-pushing callback during the source freshness + pre-check. Registering it would leave a stale entry in _dbt_runner_callbacks that fires + again for every event during the subsequent dbt build, producing duplicate log lines. + """ + producer = self._make_producer(_check_source_freshness=True) + producer._dbt_runner_callbacks = None + + context = MagicMock() + context.get.side_effect = lambda key, default=None: True if key == "_check_source_freshness" else default + + from cosmos.operators.local import DbtLocalBaseOperator + + with patch.object( + DbtLocalBaseOperator, + "run_dbt_runner", + return_value=MagicMock(), + ) as mock_super: + producer.run_dbt_runner(command=["dbt", "source", "freshness"], env={}, cwd="/tmp", context=context) + + # The callback list must remain untouched — no watcher callback appended + assert producer._dbt_runner_callbacks is None + mock_super.assert_called_once() diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index 2d5f8855b8..11929093fa 100644 --- a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -1,10 +1,9 @@ -import logging import os from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.cncf.kubernetes import __version__ as airflow_k8s_provider_version from airflow.providers.cncf.kubernetes.secret import Secret from packaging.version import Version @@ -77,8 +76,9 @@ render_config = RenderConfig(load_method=LoadMode.DBT_MANIFEST, test_behavior=TestBehavior.NONE) +@patch("cosmos.operators.watcher_kubernetes._restore_xcom_from_variable") @patch("cosmos.operators.kubernetes.DbtBuildKubernetesOperator.execute") -def test_skips_retry_attempt(mock_execute, caplog): +def test_skips_retry_attempt(mock_execute, mock_restore, caplog): """ Test that the operator skips execution when a retry is attempted (try_number > 1). """ @@ -90,15 +90,61 @@ def test_skips_retry_attempt(mock_execute, caplog): ti = MagicMock() ti.try_number = 2 - context = {"ti": ti} + context = {"ti": ti, "run_id": "test_run"} - with caplog.at_level(logging.INFO): - result = op.execute(context=context) + with pytest.raises(AirflowSkipException, match="does not support Airflow retries"): + op.execute(context=context) + mock_restore.assert_called_once_with(context) mock_execute.assert_not_called() - assert result is None - assert any("does not support Airflow retries" in message for message in caplog.messages) - assert any("skipping execution" in message for message in caplog.messages) + + +@patch("cosmos.operators.watcher_kubernetes._delete_xcom_backup_variable") +@patch("cosmos.operators.watcher_kubernetes._init_xcom_backup") +@patch("cosmos.operators.kubernetes.DbtBuildKubernetesOperator.execute") +def test_deletes_backup_on_success(mock_execute, mock_init, mock_delete): + """Test that the XCom backup Variable is deleted after a successful execution.""" + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + + ti = MagicMock() + ti.try_number = 1 + context = {"ti": ti} + + op.execute(context=context) + + mock_init.assert_called_once_with(context) + mock_delete.assert_called_once_with(context) + mock_execute.assert_called_once() + + +@patch("cosmos.operators.watcher_kubernetes._backup_xcom_to_variable") +@patch("cosmos.operators.watcher_kubernetes._delete_xcom_backup_variable") +@patch("cosmos.operators.watcher_kubernetes._init_xcom_backup") +@patch("cosmos.operators.kubernetes.DbtBuildKubernetesOperator.execute") +def test_keeps_backup_on_failure(mock_execute, mock_init, mock_delete, mock_backup): + """Test that the XCom backup Variable is persisted (not deleted) when execution fails.""" + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + + ti = MagicMock() + ti.try_number = 1 + context = {"ti": ti} + + mock_execute.side_effect = RuntimeError("dbt build failed") + + with pytest.raises(RuntimeError): + op.execute(context=context) + + mock_init.assert_called_once_with(context) + mock_backup.assert_called_once_with(context) + mock_delete.assert_not_called() def test_raises_exception_when_task_instance_missing(): @@ -174,10 +220,11 @@ def test_first_execution_behaves_as_base_consumer_sensor(mock_startup_events): @patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_and_run_cmd") def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): """ - On retry (try_number > 1), the sensor should fall back to executing - as DbtRunKubernetesOperator by calling build_and_run_cmd. + On retry (try_number > 1) with a terminated producer, the sensor should + fall back to executing as DbtRunKubernetesOperator by calling build_and_run_cmd. """ sensor = make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 @@ -191,6 +238,28 @@ def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): mock_build_and_run_cmd.assert_called_once() +@patch("cosmos.operators._watcher.base.get_xcom_val") +@patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") +def test_retry_keeps_polling_when_producer_still_running(mock_startup_events, mock_get_xcom_val): + """ + On retry (try_number > 1) with the producer still running, the sensor + should keep polling instead of launching a duplicate dbt run. + """ + sensor = make_sensor() + sensor._get_producer_task_status.return_value = "running" + + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + mock_get_xcom_val.return_value = None + context = make_context(ti) + + result = sensor.poke(context) + + assert result is False + assert sensor.poke_retry_number == 1 + + class TestCallbacksNormalization: """Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator.""" diff --git a/tests/test_config.py b/tests/test_config.py index 7797deeac1..66240189d8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ import pytest from cosmos.config import CosmosConfigException, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig -from cosmos.constants import ExecutionMode, InvocationMode +from cosmos.constants import ExecutionMode, InvocationMode, SourceRenderingBehavior from cosmos.exceptions import CosmosValueError from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping from cosmos.profiles.postgres.user_pass import PostgresUserPasswordProfileMapping @@ -271,6 +271,13 @@ def test_render_config_env_vars_deprecated(): RenderConfig(env_vars={"VAR": "value"}) +def test_render_config_source_rendering_behavior_none_warns_and_normalizes(): + """Passing None for source_rendering_behavior should warn and default to SourceRenderingBehavior.NONE.""" + with pytest.warns(UserWarning, match="Passing None for source_rendering_behavior is not supported"): + config = RenderConfig(source_rendering_behavior=None) + assert config.source_rendering_behavior == SourceRenderingBehavior.NONE + + @pytest.mark.parametrize( "execution_mode, invocation_mode, expectation", [ diff --git a/tests/test_log.py b/tests/test_log.py index 8cd102de55..275dbd0c86 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -1,16 +1,15 @@ import pytest -import cosmos.log from cosmos.log import CosmosRichLogger, get_logger from cosmos.provider_info import get_provider_info def test_get_logger(monkeypatch): - monkeypatch.setattr(cosmos.log, "rich_logging", False) + monkeypatch.setattr("cosmos.settings.rich_logging", False) standard_logger = get_logger("test-get-logger-example1") assert not isinstance(standard_logger, CosmosRichLogger) - monkeypatch.setattr(cosmos.log, "rich_logging", True) + monkeypatch.setattr("cosmos.settings.rich_logging", True) custom_logger = get_logger("test-get-logger-example2") assert isinstance(custom_logger, CosmosRichLogger) @@ -21,7 +20,7 @@ def test_get_logger(monkeypatch): def test_rich_logging(monkeypatch, caplog): - monkeypatch.setattr(cosmos.log, "rich_logging", False) + monkeypatch.setattr("cosmos.settings.rich_logging", False) standard_logger = get_logger("test-rich-logging-example1") with caplog.at_level("INFO"): @@ -32,7 +31,7 @@ def test_rich_logging(monkeypatch, caplog): assert log_output.count("\n") == 1 caplog.clear() - monkeypatch.setattr(cosmos.log, "rich_logging", True) + monkeypatch.setattr("cosmos.settings.rich_logging", True) custom_logger = get_logger("test-rich-logging-example2") with caplog.at_level("INFO"): custom_logger.info("Hello, world!") @@ -41,6 +40,15 @@ def test_rich_logging(monkeypatch, caplog): assert caplog.text.count("\n") == 1 +def test_rich_logging_none_message(monkeypatch, caplog): + """CosmosRichLogger should not crash when record.msg is None.""" + monkeypatch.setattr("cosmos.settings.rich_logging", True) + logger = get_logger("test-none-msg") + with caplog.at_level("INFO"): + logger.info(None) + assert len(caplog.records) == 1 + + def test_get_provider_info(): provider_info = get_provider_info() assert "cosmos" in provider_info.get("config").keys() diff --git a/tests/utils.py b/tests/utils.py index 70b408079d..7591b5a877 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,21 +26,22 @@ log = logging.getLogger(__name__) -def run_dag(dag: DAG, conn_file_path: str | None = None) -> DagRun: - return test_dag(dag=dag, conn_file_path=conn_file_path) +def run_dag( + dag: DAG, + conn_file_path: str | None = None, + expected_dag_state: DagRunState = DagRunState.SUCCESS, +) -> DagRun: + return test_dag(dag=dag, conn_file_path=conn_file_path, expected_dag_state=expected_dag_state) -def check_dag_success(dag_run: DagRun | None, expect_success: bool = True) -> bool: - """Check if a DAG was successful, if that Airflow version allows it.""" +def check_dag_state(dag_run: DagRun | None, expected_dag_state: DagRunState = DagRunState.SUCCESS) -> bool: + """Check if a DAG reached the expected state, if that Airflow version allows it.""" if dag_run is not None: - if expect_success: - return dag_run.state == DagRunState.SUCCESS - else: - return dag_run.state == DagRunState.FAILED + return dag_run.state == expected_dag_state return True -def new_test_dag(dag: DAG) -> DagRun: +def new_test_dag(dag: DAG, expected_dag_state: DagRunState = DagRunState.SUCCESS) -> DagRun: if AIRFLOW_VERSION >= version.Version("3.1"): # Airflow 3.1+ requires DAG to be serialized to database before calling dag.test() # because create_dagrun() checks for DagVersion and DagModel records @@ -69,18 +70,26 @@ def new_test_dag(dag: DAG) -> DagRun: dr = dag.test(logical_date=timezone.utcnow()) else: dr = dag.test() + assert check_dag_state( + dr, expected_dag_state + ), f"Dag {dag.dag_id} did not reach expected state {expected_dag_state}. State: {dr.state}." return dr def test_dag( - dag, conn_file_path: str | None = None, custom_tester: bool = False, expect_success: bool = True + dag, + conn_file_path: str | None = None, + custom_tester: bool = False, + expected_dag_state: DagRunState = DagRunState.SUCCESS, ) -> DagRun: dr = None if custom_tester: dr = test_old_dag(dag, conn_file_path) elif AIRFLOW_VERSION not in (Version("2.10.0"), Version("2.10.1"), Version("2.10.2"), Version("2.11.0")): - dr = new_test_dag(dag) - assert check_dag_success(dr, expect_success), f"Dag {dag.dag_id} did not run successfully. State: {dr.state}. " + dr = new_test_dag(dag, expected_dag_state=expected_dag_state) + assert check_dag_state( + dr, expected_dag_state + ), f"Dag {dag.dag_id} did not reach expected state {expected_dag_state}. State: {dr.state}. " else: # This is a work around until we fix the issue in Airflow: # https://github.com/apache/airflow/issues/42495 @@ -93,10 +102,10 @@ def test_dag( FAILED tests/test_example_dags.py::test_example_dag[user_defined_profile] """ try: - dr = new_test_dag(dag) - assert check_dag_success( - dr, expect_success - ), f"Dag {dag.dag_id} did not run successfully. State: {dr.state}. " + dr = new_test_dag(dag, expected_dag_state=expected_dag_state) + assert check_dag_state( + dr, expected_dag_state + ), f"Dag {dag.dag_id} did not reach expected state {expected_dag_state}. State: {dr.state}. " except sqlalchemy.exc.PendingRollbackError: warnings.warn( "Early versions of Airflow 2.10 and Airflow 2.11 have issues when running the test command with DatasetAlias / Datasets"