diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9fa593facf..0119de733b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,16 @@ Changelog 1.14.0a6 (2026-03-25) --------------------- +Breaking Changes + +* ``ExecutionMode.WATCHER``: The per-node ``*_status`` XCom value is now a dict (``{"status": "", "outlet_uris": [...]}``) instead of a plain string. Any custom code that reads these internal XCom keys directly will need to be updated. by @pankajkoti in #2507 + +Enhancements + +* Move dataset emission for ``ExecutionMode.WATCHER`` from producer to consumer sensors by @pankajkoti in #2507 + +Bug Fixes + * Fix ``cosmos_debug_max_memory_mb`` XCom not pushed in Watcher sensor tasks by @tatiana in #2503 1.13.1 (2026-02-25) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index b29c0e2b97..73de6c5748 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -685,6 +685,8 @@ def _add_watcher_producer_task( The producer task is the task that will be used to produce the events for the watcher execution mode. 
""" producer_task_args = task_args.copy() + # Producer should not emit datasets — consumer tasks handle their own emission + producer_task_args["emit_datasets"] = False if tests_per_model is not None: producer_task_args["tests_per_model"] = tests_per_model if render_config is not None and render_config.source_rendering_behavior != SourceRenderingBehavior.NONE: diff --git a/cosmos/dataset.py b/cosmos/dataset.py index 7bc7a14c2e..29618fd113 100644 --- a/cosmos/dataset.py +++ b/cosmos/dataset.py @@ -1,8 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import json +import urllib.parse +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import yaml +from packaging.version import Version + +from cosmos import settings +from cosmos.constants import AIRFLOW_VERSION +from cosmos.log import get_logger if TYPE_CHECKING: + from cosmos.config import ProfileConfig + try: from airflow.sdk import DAG # type: ignore[assignment] @@ -12,6 +24,74 @@ from airflow import DAG # type: ignore[assignment] from airflow.utils.task_group import TaskGroup +logger = get_logger(__name__) + + +def _resolve_snowflake_account(account: str) -> str: + """Normalize a Snowflake account name, delegating to the OL helper when available.""" + try: + from openlineage.common.provider.snowflake import fix_account_name + + return str(fix_account_name(account)) + except ImportError: + return account + + +def _resolve_spark_namespace(profile: dict[str, Any]) -> str: + """Derive the Spark namespace, matching OL's port-by-method logic.""" + host = profile.get("host", "localhost") + if "port" in profile: + port = profile["port"] + elif profile.get("method") in ("http", "odbc"): + port = 443 + elif profile.get("method") == "thrift": + port = 10001 + else: + port = 10000 + return f"spark://{host}:{port}" + + +def _resolve_glue_namespace(profile: dict[str, Any]) -> str: + """Derive the AWS Glue namespace, matching OL's ARN logic. 
+ + Returns an empty string when the account ID cannot be determined, + which causes ``get_dataset_namespace`` to treat it as a failed resolution. + """ + region = profile.get("region", "us-east-1") + if "account_id" in profile and profile["account_id"]: + return f"arn:aws:glue:{region}:{profile['account_id']}" + role_arn = profile.get("role_arn") + if isinstance(role_arn, str) and role_arn: + parts = role_arn.split(":") + if len(parts) > 4 and parts[4]: + return f"arn:aws:glue:{region}:{parts[4]}" + return "" + + +# Mapping of dbt adapter type to a callable that derives the OL-compatible dataset namespace +# from the profile dict. Each callable receives the profile dict and returns the namespace string. +# +# Only adapters supported by the OpenLineage dbt integration are included so that URIs produced +# by WATCHER mode match those from LOCAL mode (which delegates to the OL library). +# +# Source: https://openlineage.io/docs/spec/naming/ +# Reference impl: openlineage.common.provider.dbt.processor.DbtArtifactProcessor.extract_namespace +_ADAPTER_NAMESPACE_RESOLVERS: dict[str, Any] = { + "postgres": lambda p: f"postgres://{p.get('host', 'localhost')}:{p.get('port', 5432)}", + "redshift": lambda p: f"redshift://{p.get('host', 'localhost')}:{p.get('port', 5439)}", + "bigquery": lambda _: "bigquery", + "snowflake": lambda p: f"snowflake://{_resolve_snowflake_account(p.get('account', ''))}", + "databricks": lambda p: f"databricks://{p.get('host', '')}", + "spark": _resolve_spark_namespace, + "trino": lambda p: f"trino://{p.get('host', 'localhost')}:{p.get('port', 8080)}", + "clickhouse": lambda p: f"clickhouse://{p.get('host', 'localhost')}:{p.get('port', 8123)}", + "sqlserver": lambda p: f"mssql://{p.get('server', 'localhost')}:{p.get('port', 1433)}", + "dremio": lambda p: f"dremio://{p.get('software_host', 'localhost')}:{p.get('port', 443)}", + "athena": lambda p: f"awsathena://athena.{p.get('region_name', 'us-east-1')}.amazonaws.com", + "glue": 
_resolve_glue_namespace, + "duckdb": lambda p: f"duckdb://{p.get('path', '')}", +} + def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_id: str) -> str: """ @@ -40,3 +120,161 @@ def get_dataset_alias_name(dag: DAG | None, task_group: TaskGroup | None, task_i identifiers_list.append(task_id.split(".")[-1]) return "__".join(identifiers_list) + + +def _get_profile_dict(profile_config: ProfileConfig) -> tuple[str, dict[str, Any]]: + """ + Extract the adapter type and profile dict from a ProfileConfig. + + Supports both profile_mapping and profiles_yml_filepath modes. + + :returns: Tuple of (adapter_type, profile_dict) + """ + if profile_config.profile_mapping: + adapter_type = profile_config.profile_mapping.dbt_profile_type + profile_dict = profile_config.profile_mapping.profile + return adapter_type, profile_dict + + if profile_config.profiles_yml_filepath: + with open(profile_config.profiles_yml_filepath) as f: + profiles = yaml.safe_load(f) + target = profiles[profile_config.profile_name]["outputs"][profile_config.target_name] + adapter_type = target.get("type", "") + return adapter_type, target + + return "", {} + + +def get_dataset_namespace(profile_config: ProfileConfig) -> str | None: + """ + Derive an OpenLineage-compatible dataset namespace from a Cosmos ProfileConfig. + + The namespace is adapter-specific: e.g. ``postgres://host:5432``, ``bigquery``, + ``snowflake://account``. + + If the adapter type is recognised in ``_ADAPTER_NAMESPACE_RESOLVERS``, the + corresponding resolver builds the namespace from connection details in the + profile dict. + + Returns ``None`` if the namespace cannot be determined (e.g. missing or + invalid profile configuration, or an unsupported adapter type). When + ``None`` is returned, dataset emission is skipped for the run. 
+ """ + try: + adapter_type, profile_dict = _get_profile_dict(profile_config) + except (AttributeError, KeyError, TypeError, OSError, yaml.YAMLError): + logger.debug("Unable to extract profile info for dataset namespace derivation", exc_info=True) + return None + + if not adapter_type: + return None + + resolver = _ADAPTER_NAMESPACE_RESOLVERS.get(adapter_type) + if resolver: + namespace = str(resolver(profile_dict)) + # Reject empty or scheme-only namespaces (e.g. "snowflake://", "databricks://") + # that result from missing required profile fields. + if not namespace or namespace.endswith("://"): + logger.debug( + "Resolved namespace '%s' for adapter '%s' is incomplete; skipping dataset emission.", + namespace, + adapter_type, + ) + return None + return namespace + + # Unknown adapters: return None so dataset emission is skipped. + # Only adapters supported by the OpenLineage dbt integration are in + # _ADAPTER_NAMESPACE_RESOLVERS; emitting URIs for unsupported adapters + # would produce URIs that don't match LOCAL/VIRTUALENV behavior. + logger.debug( + "Adapter type '%s' is not supported for dataset namespace resolution; skipping dataset emission.", adapter_type + ) + return None + + +def construct_dataset_uri(namespace: str, name: str) -> str: + """ + Construct an Airflow Asset/Dataset URI from an OL-compatible namespace and + dataset name. + + On Airflow 2, dots in the name are preserved (``namespace/db.schema.table``). + On Airflow 3 (or when ``use_dataset_airflow3_uri_standard`` is enabled), + dots are replaced with slashes (``namespace/db/schema/table``). + + :param namespace: The OL-compatible dataset namespace (e.g. ``postgres://host:5432``). + :param name: The dot-delimited dataset name (e.g. ``database.schema.table``). 
+ """ + airflow_2_uri = namespace + "/" + urllib.parse.quote(name) + airflow_3_uri = namespace + "/" + urllib.parse.quote(name).replace(".", "/") + + if AIRFLOW_VERSION < Version("3.0.0"): + if settings.use_dataset_airflow3_uri_standard: + return airflow_3_uri + logger.warning( + "Airflow 3.0.0 Asset (Dataset) URIs validation rules changed and OpenLineage URIs " + "(standard used by Cosmos) will no longer be valid. " + "Therefore, if using Cosmos with Airflow 3, the Airflow Dataset URIs will be changed to <%s>. " + "Previously, with Airflow 2.x, the URI was <%s>. " + "If you want to use the Airflow 3 URI standard while still using Airflow 2, please set: " + "export AIRFLOW__COSMOS__USE_DATASET_AIRFLOW3_URI_STANDARD=1 " + "Remember to update any DAGs that are scheduled using this dataset.", + airflow_3_uri, + airflow_2_uri, + ) + return airflow_2_uri + + logger.warning( + "Airflow 3.0.0 Asset (Dataset) URIs validation rules changed and OpenLineage URIs " + "(standard used by Cosmos) are no longer accepted. " + "Therefore, if using Cosmos with Airflow 3, the Airflow Asset (Dataset) URI is now <%s>. " + "Before, with Airflow 2.x, the URI used to be <%s>. " + "Please, change any DAGs that were scheduled using the old standard to the new one.", + airflow_3_uri, + airflow_2_uri, + ) + return airflow_3_uri + + +def compute_model_outlet_uris(manifest_path: str | Path, namespace: str) -> dict[str, list[str]]: + """ + Read a dbt manifest and compute per-model outlet URIs. + + :param manifest_path: Path to the ``manifest.json`` file. + :param namespace: The OL-compatible dataset namespace (e.g. ``postgres://host:5432``). + :returns: Mapping of ``{unique_id: [uri]}`` for model, seed, and snapshot nodes. + """ + _RESOURCE_TYPES_WITH_DATASETS = {"model", "seed", "snapshot"} + + # The manifest may not exist if dbt failed before completing compilation, + # or if the temp project directory was cleaned up before this function ran. 
+ # In those cases we gracefully return an empty dict — consumers simply + # won't emit datasets for the affected run. + try: + with open(manifest_path) as f: + manifest = json.load(f) + except (OSError, json.JSONDecodeError): + logger.warning( + "Unable to read manifest at %s for dataset URI computation. Dataset emission will be skipped.", + manifest_path, + exc_info=True, + ) + return {} + + result: dict[str, list[str]] = {} + for unique_id, node in manifest.get("nodes", {}).items(): + resource_type = node.get("resource_type", "") + if resource_type not in _RESOURCE_TYPES_WITH_DATASETS: + continue + + database = node.get("database", "") + schema = node.get("schema", "") + alias = node.get("alias") or node.get("name", "") + + if not all([database, schema, alias]): + continue + + uri = construct_dataset_uri(namespace, f"{database}.{schema}.{alias}") + result[unique_id] = [uri] + + return result diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 7a7ebc3cce..aa47e81291 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -2,6 +2,7 @@ import json import logging +import threading from pathlib import Path from typing import Any @@ -203,12 +204,47 @@ def _log_dbt_msg(log_line: dict[str, Any]) -> None: logger.log(getattr(logging, level, logging.INFO), msg) +_model_outlet_uris_lock = threading.Lock() + + +_MODEL_OUTLET_URIS_ATTEMPTED_KEY = "__attempted__" + + +def _ensure_subprocess_model_outlet_uris( + model_outlet_uris: dict[str, list[str]] | None, + dataset_namespace: str | None, + project_dir: str | None, +) -> None: + """Lazily populate model_outlet_uris from the manifest when first needed. + + Thread-safe: in DBT_RUNNER mode the log parser callback can be invoked from + multiple dbt threads. The lock prevents duplicate manifest reads and ensures + the dict is fully populated before any reader accesses it. 
+ + A sentinel key is set after the attempt completes so that an empty manifest + result does not cause repeated filesystem reads. + """ + if model_outlet_uris is None or not dataset_namespace or not project_dir: + return + with _model_outlet_uris_lock: + if _MODEL_OUTLET_URIS_ATTEMPTED_KEY in model_outlet_uris: + return + from cosmos.dataset import compute_model_outlet_uris + + manifest_path = Path(project_dir) / "target" / "manifest.json" + if manifest_path.exists(): + model_outlet_uris.update(compute_model_outlet_uris(manifest_path, dataset_namespace)) + model_outlet_uris[_MODEL_OUTLET_URIS_ATTEMPTED_KEY] = [] # type: ignore[assignment] + + def store_dbt_resource_status_from_log( line: str, extra_kwargs: Any, *, tests_per_model: dict[str, list[str]] | None = None, test_results_per_model: dict[str, list[str]] | None = None, + model_outlet_uris: dict[str, list[str]] | None = None, + dataset_namespace: str | None = None, ) -> None: """ Parses a single line from dbt JSON logs and stores node status to Airflow XCom. @@ -226,6 +262,9 @@ def store_dbt_resource_status_from_log( tests, collects the terminal statuses of those tests as they finish. Keyed by model unique_id, values are lists of test statuses (e.g. ``["pass", "pass"]``). Mutated in place by this function. + :param model_outlet_uris: Mutable dict mapping unique_id to outlet URIs. + Populated lazily from the manifest on first terminal status detection. + :param dataset_namespace: The OL-compatible dataset namespace for URI construction. """ try: log_line = json.loads(line) @@ -265,8 +304,22 @@ def store_dbt_resource_status_from_log( unique_id, dbt_node_status, tests_per_model, test_results_per_model, context["ti"] ) else: + # Lazily populate per-model outlet URIs from the manifest, but only for + # resource types that can emit datasets (models/seeds/snapshots). 
+ outlet_uris: list[str] = [] + if dbt_node_resource_type in {"model", "seed", "snapshot"}: + project_dir = extra_kwargs.get("project_dir") + _ensure_subprocess_model_outlet_uris(model_outlet_uris, dataset_namespace, project_dir) + outlet_uris = model_outlet_uris.get(unique_id, []) if model_outlet_uris else [] + + status_value: dict[str, Any] | str = { + "status": dbt_node_status, + "outlet_uris": outlet_uris, + } safe_xcom_push( - task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=dbt_node_status + task_instance=context["ti"], + key=f"{unique_id.replace('.', '__')}_status", + value=status_value, ) # Extract and push compiled_sql for models (centralised for both subprocess and node-event) @@ -449,7 +502,7 @@ def execute(self, context: Context, **kwargs: Any) -> None: else: self._execute_core(context) - def execute_complete(self, context: Context, event: dict[str, str]) -> None: + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: status = event.get("status") reason = event.get("reason") @@ -505,13 +558,20 @@ def _get_node_status(self, ti: Any, context: Context) -> Any: For test sensors, reads the aggregated ``_tests_status`` key. For model sensors, reads the per-model ``*_status`` key (same for both - SUBPROCESS and DBT_RUNNER invocation modes). + SUBPROCESS and DBT_RUNNER invocation modes). The value is always a dict + with ``status`` and ``outlet_uris`` keys. + + Side effect: stores outlet URIs on ``self._outlet_uris`` for later + dataset emission. 
""" if self.is_test_sensor: xcom_key = get_tests_status_xcom_key(self.model_unique_id) - else: - xcom_key = f"{self.model_unique_id.replace('.', '__')}_status" - return get_xcom_val(ti, self.producer_task_id, xcom_key) + return get_xcom_val(ti, self.producer_task_id, xcom_key) + xcom_val = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + if xcom_val is None: + return None + self._outlet_uris = xcom_val.get("outlet_uris", []) + return xcom_val.get("status") def _cache_compiled_sql(self, ti: Any, context: Context) -> None: """Pull compiled_sql from XCom and cache it on the sensor instance.""" diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 0d71191491..c104a3c334 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -118,8 +118,17 @@ async def get_xcom_val(self, key: str) -> Any | None: return await self.get_xcom_val_af3(key) async def _get_node_status(self) -> Any | None: + """Return the dbt node status from XCom. + + The XCom value is always a dict with ``status`` and ``outlet_uris`` keys. + Stores outlet URIs on ``self._outlet_uris`` for later dataset emission. 
+ """ status_key = f"{self.model_unique_id.replace('.', '__')}_status" - return await self.get_xcom_val(status_key) + xcom_val = await self.get_xcom_val(status_key) + if xcom_val is None: + return None + self._outlet_uris = xcom_val.get("outlet_uris", []) + return xcom_val.get("status") async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]: """ @@ -216,6 +225,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: event_data: dict[str, Any] = {"status": EventStatus.SUCCESS} if compiled_sql: event_data["compiled_sql"] = compiled_sql + # Pass outlet URIs through TriggerEvent so consumer can emit datasets + outlet_uris = getattr(self, "_outlet_uris", []) + if outlet_uris: + event_data["outlet_uris"] = outlet_uris yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return elif is_dbt_node_status_skipped(dbt_node_status): diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index c0a4f2bcda..b77c112099 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -6,7 +6,6 @@ import os import tempfile import time -import urllib.parse import warnings import zlib from abc import ABC, abstractmethod @@ -50,7 +49,7 @@ FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode, ) -from cosmos.dataset import get_dataset_alias_name +from cosmos.dataset import construct_dataset_uri, get_dataset_alias_name from cosmos.dbt.project import ( copy_dbt_packages, copy_manifest_file_if_exists, @@ -75,7 +74,6 @@ if TYPE_CHECKING: # pragma: no cover - import openlineage # pragma: no cover from dbt.cli.main import dbtRunner, dbtRunnerResult try: # pragma: no cover @@ -136,7 +134,6 @@ def _read_target_sources_json(project_root: Path) -> dict[str, Any] | None: # The following is related to the ability of Cosmos parsing dbt artifacts and generating OpenLineage URIs # It is used for emitting Airflow assets and not necessarily OpenLineage events try: - import openlineage from openlineage.common.provider.dbt.local import 
DbtLocalArtifactProcessor is_openlineage_common_available = True @@ -747,38 +744,6 @@ def calculate_openlineage_events_completes( except (FileNotFoundError, NotImplementedError, ValueError, KeyError, jinja2.exceptions.UndefinedError): logger.debug("Unable to parse OpenLineage events", stack_info=True) - @staticmethod - def _create_asset_uri(openlineage_event: openlineage.client.generated.base.OutputDataset) -> str: - """ - Create the Airflow Asset or Dataset UIR given an OpenLineage event. - """ - airflow_2_uri = str(openlineage_event.namespace + "/" + urllib.parse.quote(openlineage_event.name)) - airflow_3_uri = str( - openlineage_event.namespace + "/" + urllib.parse.quote(openlineage_event.name).replace(".", "/") - ) - if AIRFLOW_VERSION < Version("3.0.0"): - if settings.use_dataset_airflow3_uri_standard: - dataset_uri = airflow_3_uri - else: - logger.warning(f""" - Airflow 3.0.0 Asset (Dataset) URIs validation rules changed and OpenLineage URIs (standard used by Cosmos) will no longer be valid. - Therefore, if using Cosmos with Airflow 3, the Airflow Dataset URIs will be changed to <{airflow_3_uri}>. - Previously, with Airflow 2.x, the URI was <{airflow_2_uri}>. - If you want to use the Airflow 3 URI standard while still using Airflow 2, please, set: - export AIRFLOW__COSMOS__USE_DATASET_AIRFLOW3_URI_STANDARD=1 - Remember to update any DAGs that are scheduled using this dataset. - """) - dataset_uri = airflow_2_uri - else: - logger.warning(f""" - Airflow 3.0.0 Asset (Dataset) URIs validation rules changed and OpenLineage URIs (standard used by Cosmos) are no longer accepted. - Therefore, if using Cosmos with Airflow 3, the Airflow Asset (Dataset) URI is now <{airflow_3_uri}>. - Before, with Airflow 2.x, the URI used to be <{airflow_2_uri}>. - Please, change any DAGs that were scheduled using the old standard to the new one. 
- """) - dataset_uri = airflow_3_uri - return dataset_uri - def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Asset]: """ Use openlineage-integration-common to extract lineage events from the artifacts generated after running the dbt @@ -793,7 +758,7 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Asset]: for completed in self.openlineage_events_completes: for output in getattr(completed, source): - dataset_uri = self._create_asset_uri(output) + dataset_uri = construct_dataset_uri(output.namespace, output.name) uris.append(dataset_uri) logger.debug("URIs to be converted to Asset: %s", uris) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 30540a6591..b007a4f433 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -20,6 +20,7 @@ WATCHER_TASK_WEIGHT_RULE, DbtResourceType, ) +from cosmos.dataset import get_dataset_namespace from cosmos.dbt.graph import DbtNode from cosmos.log import get_logger from cosmos.operators._watcher import safe_xcom_push @@ -167,12 +168,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(task_id=task_id, *args, **kwargs) self.log_format = "json" + # Mutable dict populated lazily from the manifest; shared with the log parser. 
+ self._dataset_namespace: str | None = None + self._model_outlet_uris: dict[str, list[str]] = {} + + def _handle_datasets(self, context: Context) -> None: + """No-op override: consumer tasks handle their own dataset emission in WATCHER mode.""" + def _make_parse_callable(self) -> Callable[[str, Any], None]: """Returns store_dbt_resource_status_from_log with the operator's test maps pre-bound.""" return functools.partial( store_dbt_resource_status_from_log, tests_per_model=self.tests_per_model, test_results_per_model=self.test_results_per_model, + model_outlet_uris=self._model_outlet_uris, + dataset_namespace=self._dataset_namespace, ) def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any: @@ -240,7 +250,7 @@ def _push_skipped_xcom_for_model(self, ti: Any, unique_id: str) -> None: Uses the unified ``*_status`` XCom key that consumer sensors already poll. """ uid_key = unique_id.replace(".", "__") - safe_xcom_push(task_instance=ti, key=f"{uid_key}_status", value="skipped") + safe_xcom_push(task_instance=ti, key=f"{uid_key}_status", value={"status": "skipped", "outlet_uris": []}) def _run_source_freshness(self, context: Context) -> None: """Run ``dbt source freshness`` via ``build_cmd`` and ``run_command``. 
@@ -303,7 +313,11 @@ def _push_source_freshness_results(self, context: Context) -> None: status = result.get("status") if unique_id and status: uid_key = unique_id.replace(".", "__") - safe_xcom_push(task_instance=ti, key=f"{uid_key}_status", value=status) + safe_xcom_push( + task_instance=ti, + key=f"{uid_key}_status", + value={"status": status, "outlet_uris": []}, + ) def _apply_source_freshness(self, context: Context) -> None: """Run source freshness, invoke the callback, and mark affected nodes as skipped.""" @@ -331,6 +345,10 @@ def _apply_source_freshness(self, context: Context) -> None: self._skipped_node_token(context, node_ids_to_skip) def execute(self, context: Context, **kwargs: Any) -> Any: + # Pre-compute the dataset namespace for per-model outlet URI generation. + self._dataset_namespace = get_dataset_namespace(self.profile_config) + self._model_outlet_uris.clear() + task_instance = context.get("ti") if task_instance is None: raise AirflowException("DbtProducerWatcherOperator expects a task instance in the execution context") @@ -382,6 +400,41 @@ def __init__( **kwargs, ) + def _emit_datasets(self, context: Context) -> None: + """Emit Airflow datasets for this consumer task's model using outlet URIs from the producer.""" + if not getattr(self, "emit_datasets", False): + return + outlet_uris = getattr(self, "_outlet_uris", []) + if not outlet_uris: + return + + from cosmos import settings + from cosmos.constants import AIRFLOW_VERSION + + if AIRFLOW_VERSION.major >= 3: + from airflow.sdk.definitions.asset import Asset + else: + from airflow.datasets import Dataset as Asset # type: ignore[no-redef] + + outlets = [Asset(uri=uri) for uri in outlet_uris] + logger.info("Emitting %d dataset(s) for model '%s': %s", len(outlets), self.model_unique_id, outlet_uris) + self.register_dataset([], outlets, context) + + if settings.enable_uri_xcom: + context["ti"].xcom_push(key="uri", value=outlet_uris) + + def execute(self, context: Context, **kwargs: Any) -> None: 
# type: ignore[override] + super().execute(context, **kwargs) + # If we reach here without deferring, the model succeeded — emit datasets + self._emit_datasets(context) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + # Extract outlet URIs from trigger event before parent handles status + self._outlet_uris = event.get("outlet_uris", []) + super().execute_complete(context, event) + # If we reach here without raising, the model succeeded — emit datasets + self._emit_datasets(context) + # This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task. # That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. diff --git a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst index 58eef23043..a412103a1c 100644 --- a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst +++ b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst @@ -350,7 +350,14 @@ The ``TestBehavior.NONE`` and ``TestBehavior.AFTER_ALL`` behave similarly to ``E Airflow Datasets and Assets ''''''''''''''''''''''''''' -While the ``ExecutionMode.WATCHER`` supports the ``emit_datasets`` parameter, the Airflow Datasets and Assets are emitted from the ``DbtProducerWatcherOperator`` task instead of the consumer tasks, as done for other Cosmos' execution modes. +.. versionchanged:: 1.14.0 + +The ``ExecutionMode.WATCHER`` supports the ``emit_datasets`` parameter. Starting with Cosmos 1.14.0, +each **consumer** sensor task emits the Airflow Dataset (Asset) for the dbt model it watches, matching +the per-model granularity of other execution modes. The **producer** task does not emit datasets. + +For details on URI patterns and how dataset emission differs between execution modes, see +:ref:`data-aware-scheduling`. 
Source freshness nodes '''''''''''''''''''''' diff --git a/docs/guides/run_dbt/customization/scheduling.rst b/docs/guides/run_dbt/customization/scheduling.rst index 201b0ff0b2..9f3e102156 100644 --- a/docs/guides/run_dbt/customization/scheduling.rst +++ b/docs/guides/run_dbt/customization/scheduling.rst @@ -131,6 +131,122 @@ From Cosmos 1.7 and Airflow 2.10, it is also possible to trigger DAGs be to be r t3 +Dataset URI Patterns +'''''''''''''''''''' + +Cosmos generates dataset URIs using the `OpenLineage naming convention `_. +Each URI is composed of a **namespace** (derived from the database connection) and a **name** (derived from the dbt model's +database, schema, and table/alias). + +The general format is: + +- **Airflow 2**: ``/..`` +- **Airflow 3**: ``///
`` + +The namespace depends on your database adapter: + ++---------------+-----------------------------------------------+------------------------------+ +| Adapter | Namespace format | Example | ++===============+===============================================+==============================+ +| Postgres | ``postgres://:`` | ``postgres://myhost:5432`` | ++---------------+-----------------------------------------------+------------------------------+ +| Snowflake | ``snowflake://`` | ``snowflake://xy12345`` | ++---------------+-----------------------------------------------+------------------------------+ +| BigQuery | ``bigquery`` | ``bigquery`` | ++---------------+-----------------------------------------------+------------------------------+ +| Redshift | ``redshift://:`` | ``redshift://myhost:5439`` | ++---------------+-----------------------------------------------+------------------------------+ +| Databricks | ``databricks://`` | ``databricks://myhost`` | ++---------------+-----------------------------------------------+------------------------------+ +| Trino | ``trino://:`` | ``trino://myhost:8080`` | ++---------------+-----------------------------------------------+------------------------------+ +| SQL Server | ``mssql://:`` | ``mssql://myserver:1433`` | ++---------------+-----------------------------------------------+------------------------------+ +| Athena | ``awsathena://athena..amazonaws.com`` | ``awsathena://athena...`` | ++---------------+-----------------------------------------------+------------------------------+ +| DuckDB | ``duckdb://`` | ``duckdb:///tmp/my.db`` | ++---------------+-----------------------------------------------+------------------------------+ +| ClickHouse | ``clickhouse://:`` | ``clickhouse://myhost:8123`` | ++---------------+-----------------------------------------------+------------------------------+ +| Spark | ``spark://:`` | ``spark://myhost:10001`` | 
++---------------+-----------------------------------------------+------------------------------+ + +.. note:: + + Dataset emission is only supported for adapters that the `OpenLineage dbt integration + `_ supports. + Adapters like MySQL, Oracle, Vertica, Exasol, and Teradata do not emit datasets in any execution mode. + +For example, a Postgres model ``customers`` in database ``analytics`` and schema ``public`` running on +host ``myhost:5432`` produces: + +- **Airflow 2**: ``postgres://myhost:5432/analytics.public.customers`` +- **Airflow 3**: ``postgres://myhost:5432/analytics/public/customers`` + +To schedule a downstream DAG on this dataset, use the URI directly: + +.. code-block:: python + + from airflow.datasets import Dataset + + # Airflow 2 + my_dataset = Dataset(uri="postgres://myhost:5432/analytics.public.customers") + + # Airflow 3 + my_dataset = Dataset(uri="postgres://myhost:5432/analytics/public/customers") + +You can also inspect the URIs emitted by a Cosmos DAG at runtime by enabling the +``enable_uri_xcom`` setting — see `Emitting Dataset URIs as XCom`_ below. + + +How Dataset Emission Differs by Execution Mode +''''''''''''''''''''''''''''''''''''''''''''''' + ++-------------------------------+---------------------------------------------+ +| Execution Mode | How datasets are emitted | ++===============================+=============================================+ +| ``LOCAL``, ``VIRTUALENV`` | Each dbt task emits its own datasets | +| | using the OpenLineage library to parse | +| | dbt artifacts (``manifest.json``, | +| | ``run_results.json``) after execution. | ++-------------------------------+---------------------------------------------+ +| ``WATCHER`` | The **producer** task runs ``dbt build`` | +| | once but does **not** emit datasets. | +| | Instead, it reads the manifest to compute | +| | per-model URIs and passes them to | +| | **consumer** sensor tasks via XCom. 
| +| | Each consumer emits the dataset for its | +| | own model upon successful completion. | +| | This preserves per-model granularity in | +| | Airflow's data-aware scheduling. | ++-------------------------------+---------------------------------------------+ +| ``AIRFLOW_ASYNC`` | Each task emits its own datasets after | +| | the async SQL execution completes. | +| | URIs are constructed directly from the | +| | task's BigQuery project, dataset, and | +| | table name (BigQuery only). | ++-------------------------------+---------------------------------------------+ + +.. important:: + + **WATCHER mode caveats for dataset emission:** + + In ``LOCAL`` and ``VIRTUALENV`` modes, Cosmos delegates namespace resolution entirely to the + `OpenLineage dbt integration <https://openlineage.io/docs/integrations/dbt>`_ + library, which parses dbt artifacts at runtime. This means URIs always reflect the latest + OpenLineage naming conventions. + + In ``WATCHER`` mode, Cosmos derives the dataset namespace from your ``ProfileConfig`` + using its own adapter-to-namespace mapping. This mapping was aligned with the OpenLineage + library as of Cosmos 1.14.0. If a future OpenLineage release changes its namespace logic + (e.g. changes a scheme or adds a new adapter), there is a window where ``WATCHER`` mode + may produce slightly different URIs than ``LOCAL`` mode until Cosmos is updated to match. + + If you depend on exact URI parity between execution modes, verify that both modes produce + the same URIs after upgrading either Cosmos or the OpenLineage library. You can inspect + emitted URIs by enabling the ``enable_uri_xcom`` setting (see `Emitting Dataset URIs as XCom`_).
+ + Known Limitations ''''''''''''''''' diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 6611a81b84..aec4e4a532 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -98,7 +98,9 @@ def test_store_dbt_resource_status_from_log_param(status, context, expect_xcom_p store_dbt_resource_status_from_log(line, {"context": context}, tests_per_model={}, test_results_per_model={}) if expect_xcom_push: mock_push.assert_called_with( - task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status + task_instance=context["ti"], + key="model__jaffle_shop__stg_orders_status", + value={"status": status, "outlet_uris": []}, ) assert mock_push.call_count == 1 else: diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index 65276115a1..05092e776a 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -82,11 +82,25 @@ async def runner(*args, **kwargs): none_result = await self.trigger.get_xcom_val_af2("test_key") assert none_result is None + async def test_get_node_status_with_dict_xcom(self): + """When XCom value is a dict, extract status and outlet_uris.""" + xcom_dict = {"status": "success", "outlet_uris": ["postgres://host:5432/db/schema/table"]} + with patch.object(self.trigger, "get_xcom_val", AsyncMock(return_value=xcom_dict)): + status = await self.trigger._get_node_status() + assert status == "success" + assert self.trigger._outlet_uris == ["postgres://host:5432/db/schema/table"] + + async def test_get_node_status_with_none_xcom(self): + """When XCom value is None (not yet pushed), return None.""" + with patch.object(self.trigger, "get_xcom_val", AsyncMock(return_value=None)): + status = await self.trigger._get_node_status() + assert status is None + @pytest.mark.parametrize( "xcom_val, expected_status, expected_compiled_sql", [ - ("failed", "failed", None), - ("success", "success", "SELECT * 
FROM table"), + ({"status": "failed", "outlet_uris": []}, "failed", None), + ({"status": "success", "outlet_uris": []}, "success", "SELECT * FROM table"), (None, None, None), ], ) @@ -135,11 +149,15 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): assert val == "af3" @pytest.mark.parametrize( - "dbt_node_status, producer_state, expected", + "xcom_status_val, producer_state, expected", [ - ("success", "running", {"status": "success"}), - ("skipped", "running", {"status": "skipped"}), - ("failed", "running", {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}), + ({"status": "success", "outlet_uris": []}, "running", {"status": "success"}), + ({"status": "skipped", "outlet_uris": []}, "running", {"status": "skipped"}), + ( + {"status": "failed", "outlet_uris": []}, + "running", + {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}, + ), (None, "failed", {"status": "failed", "reason": WatcherEventReason.PRODUCER_FAILED}), (None, "success", {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}), ], @@ -147,7 +165,7 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): @patch("cosmos.operators._watcher.triggerer.WatcherTrigger._log_startup_events") @patch("cosmos.operators._watcher.triggerer._log_dbt_event") async def test_run_various_outcomes( - self, mock_dbt_event, mock_startup_events, dbt_node_status, producer_state, expected + self, mock_dbt_event, mock_startup_events, xcom_status_val, producer_state, expected ): async def fake_get_xcom_val(key): if key == _DBT_STARTUP_EVENTS_XCOM_KEY: @@ -155,7 +173,7 @@ async def fake_get_xcom_val(key): if key.endswith("_compiled_sql"): return None if key.endswith("_status"): - return dbt_node_status + return xcom_status_val return None with ( @@ -165,6 +183,31 @@ async def fake_get_xcom_val(key): events = [event async for event in self.trigger.run()] assert events[0].payload == expected + 
@patch("cosmos.operators._watcher.triggerer.WatcherTrigger._log_startup_events") + @patch("cosmos.operators._watcher.triggerer._log_dbt_event") + async def test_run_success_with_outlet_uris(self, mock_dbt_event, mock_startup_events): + """When model succeeds and outlet URIs are present, TriggerEvent includes them.""" + outlet_uris = ["postgres://host:5432/db/schema/table"] + xcom_dict = {"status": "success", "outlet_uris": outlet_uris} + + async def fake_get_xcom_val(key): + if key == _DBT_STARTUP_EVENTS_XCOM_KEY: + return _STARTUP_EVENTS + if key.endswith("_compiled_sql"): + return None + if key.endswith("_status"): + return xcom_dict + return None + + with ( + patch.object(self.trigger, "get_xcom_val", side_effect=fake_get_xcom_val), + patch.object(self.trigger, "_get_producer_task_status", AsyncMock(return_value="running")), + ): + events = [event async for event in self.trigger.run()] + assert len(events) == 1 + assert events[0].payload["status"] == "success" + assert events[0].payload["outlet_uris"] == outlet_uris + @pytest.mark.asyncio async def test_get_producer_task_status_airflow2(self): fetcher = MagicMock(return_value="failed") diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 09ef0f774b..62f78cb647 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -274,6 +274,54 @@ def test_dbt_consumer_watcher_sensor_execute_complete(event, expected_message): assert str(excinfo.value) == expected_message +class TestConsumerEmitDatasets: + """Tests for DbtConsumerWatcherSensor._emit_datasets.""" + + def _make_sensor(self, emit_datasets=True): + sensor = DbtConsumerWatcherSensor( + project_dir=".", + profile_config=profile_config, + task_id="test_sensor", + extra_context={"dbt_node_config": {"unique_id": "model.pkg.my_model"}}, + deferrable=False, + ) + sensor.emit_datasets = emit_datasets + return sensor + + def test_emit_datasets_skipped_when_disabled(self): + sensor = 
self._make_sensor(emit_datasets=False) + ctx = {"ti": _MockTI()} + # Should not raise + sensor._emit_datasets(ctx) + + def test_emit_datasets_skipped_when_no_uris(self): + sensor = self._make_sensor() + sensor._outlet_uris = [] + ctx = {"ti": _MockTI()} + sensor._emit_datasets(ctx) + + @patch.object(DbtConsumerWatcherSensor, "register_dataset") + def test_emit_datasets_calls_register(self, mock_register): + sensor = self._make_sensor() + sensor._outlet_uris = ["postgres://host:5432/db/schema/table"] + ti = _MockTI() + ctx = {"ti": ti} + sensor._emit_datasets(ctx) + mock_register.assert_called_once() + args = mock_register.call_args + assert len(args[0][1]) == 1 # one outlet + + @patch("cosmos.settings.enable_uri_xcom", True) + @patch.object(DbtConsumerWatcherSensor, "register_dataset") + def test_emit_datasets_pushes_xcom_when_enabled(self, mock_register): + sensor = self._make_sensor() + sensor._outlet_uris = ["postgres://host:5432/db/schema/table"] + ti = _MockTI() + ctx = {"ti": ti} + sensor._emit_datasets(ctx) + assert ti.store.get("uri") == ["postgres://host:5432/db/schema/table"] + + class TestStoreDbtStatusFromLog: """Tests for store_dbt_resource_status_from_log and _process_log_line_callable.""" @@ -286,7 +334,7 @@ def test_store_dbt_resource_status_from_log_success(self): store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) - assert ti.store.get("model__pkg__my_model_status") == "success" + assert ti.store.get("model__pkg__my_model_status") == {"status": "success", "outlet_uris": []} def test_store_dbt_resource_status_from_log_failed(self): """Test that failed status is correctly parsed and stored in XCom.""" @@ -297,7 +345,7 @@ def test_store_dbt_resource_status_from_log_failed(self): store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) - assert ti.store.get("model__pkg__failed_model_status") == "failed" + assert 
ti.store.get("model__pkg__failed_model_status") == {"status": "failed", "outlet_uris": []} def test_store_dbt_resource_status_from_log_ignores_other_statuses(self): """Test that statuses other than success/failed are ignored.""" @@ -330,7 +378,7 @@ def test_store_dbt_resources_status_from_log_detects_passed_test_status(self): store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) - assert ti.store.get("test__pkg__my_test_status") == "pass" + assert ti.store.get("test__pkg__my_test_status") == {"status": "pass", "outlet_uris": []} def test_store_dbt_resource_status_from_log_detects_failed_test_status(self): """Test that a failed test status is correctly parsed and stored in XCom.""" @@ -350,7 +398,7 @@ def test_store_dbt_resource_status_from_log_detects_failed_test_status(self): store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) - assert ti.store.get("test__pkg__my_test_status") == "fail" + assert ti.store.get("test__pkg__my_test_status") == {"status": "fail", "outlet_uris": []} def test_store_dbt_resource_status_from_log_aggregates_test_results_when_tests_per_model_provided(self): """When tests_per_model is non-empty and a test node finishes, the function should @@ -565,8 +613,8 @@ def test_process_log_line_callable_integration_with_subprocess_pattern(self): for line in log_lines: op._process_log_line_callable(line, kwargs) - assert ti.store.get("model__pkg__model_a_status") == "success" - assert ti.store.get("model__pkg__model_b_status") == "failed" + assert ti.store.get("model__pkg__model_a_status") == {"status": "success", "outlet_uris": []} + assert ti.store.get("model__pkg__model_b_status") == {"status": "failed", "outlet_uris": []} assert len(ti.store) == 2 # Only success and failed statuses are stored def test_store_dbt_resource_status_from_log_pushes_compiled_sql_for_models(self, tmp_path): @@ -596,7 +644,7 @@ def 
test_store_dbt_resource_status_from_log_pushes_compiled_sql_for_models(self, store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) - assert ti.store.get("model__pkg__my_model_status") == "success" + assert ti.store.get("model__pkg__my_model_status") == {"status": "success", "outlet_uris": []} assert ti.store.get("model__pkg__my_model_compiled_sql") == "SELECT * FROM orders WHERE status = 'completed'" def test_store_dbt_resource_status_from_log_no_compiled_sql_for_non_models(self, tmp_path): @@ -620,7 +668,7 @@ def test_store_dbt_resource_status_from_log_no_compiled_sql_for_non_models(self, store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) - assert ti.store.get("seed__pkg__my_seed_status") == "success" + assert ti.store.get("seed__pkg__my_seed_status") == {"status": "success", "outlet_uris": []} assert "seed__pkg__my_seed_compiled_sql" not in ti.store def test_store_dbt_resource_status_from_log_pushes_compiled_sql_on_failure(self, tmp_path): @@ -649,7 +697,7 @@ def test_store_dbt_resource_status_from_log_pushes_compiled_sql_on_failure(self, store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) - assert ti.store.get("model__pkg__failed_model_status") == "failed" + assert ti.store.get("model__pkg__failed_model_status") == {"status": "failed", "outlet_uris": []} assert ti.store.get("model__pkg__failed_model_compiled_sql") == "SELECT * FROM orders WHERE invalid_column = 1" @@ -944,8 +992,8 @@ def test_poke_success(self): ti = MagicMock() ti.try_number = 1 - # xcom_pull calls: _log_startup_events=None, _get_node_status="success", compiled_sql=None, _dbt_event=None - ti.xcom_pull.side_effect = [None, "success", None, None] + # xcom_pull calls: _log_startup_events=None, _get_node_status=dict, compiled_sql=None, _dbt_event=None + ti.xcom_pull.side_effect = [None, {"status": "success", "outlet_uris": []}, None, None] context = self.make_context(ti) result 
= sensor.poke(context) @@ -960,8 +1008,8 @@ def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_overrid ti = MagicMock() ti.try_number = 1 - # xcom_pull calls: _log_startup_events=None, _get_node_status="success", compiled_sql="SELECT * FROM orders", _dbt_event=None - ti.xcom_pull.side_effect = [None, "success", "SELECT * FROM orders", None] + # xcom_pull calls: _log_startup_events=None, _get_node_status=dict, compiled_sql="SELECT * FROM orders", _dbt_event=None + ti.xcom_pull.side_effect = [None, {"status": "success", "outlet_uris": []}, "SELECT * FROM orders", None] context = self.make_context(ti) assert sensor.compiled_sql == "" # Initially empty @@ -976,8 +1024,8 @@ def test_poke_failure(self): ti = MagicMock() ti.try_number = 1 - # xcom_pull calls: _log_startup_events=None, _get_node_status="failed", compiled_sql=None, _dbt_event=None - ti.xcom_pull.side_effect = [None, "failed", None, None] + # xcom_pull calls: _log_startup_events=None, _get_node_status=dict, compiled_sql=None, _dbt_event=None + ti.xcom_pull.side_effect = [None, {"status": "failed", "outlet_uris": []}, None, None] context = self.make_context(ti) with pytest.raises(AirflowException): @@ -1219,10 +1267,9 @@ async def test_parse_dbt_node_status_and_compiled_sql_subprocess_mode(self): """Test that compiled_sql is extracted from XCom in subprocess mode.""" trigger = self.make_trigger() - # Mock get_xcom_val to return status and compiled_sql async def mock_get_xcom_val(key): if key == "model__pkg__my_model_status": - return "success" + return {"status": "success", "outlet_uris": []} elif key == "model__pkg__my_model_compiled_sql": return "SELECT * FROM orders" return None @@ -1239,10 +1286,9 @@ async def test_parse_dbt_node_status_and_compiled_sql_subprocess_no_compiled_sql """Test that missing compiled_sql is handled gracefully in subprocess mode.""" trigger = self.make_trigger() - # Mock get_xcom_val to return only status async def mock_get_xcom_val(key): if key == 
"model__pkg__my_model_status": - return "success" + return {"status": "success", "outlet_uris": []} return None trigger.get_xcom_val = mock_get_xcom_val @@ -2054,7 +2100,9 @@ def test_push_skipped_xcom_for_model(self): producer = self._make_producer() ti = MagicMock() producer._push_skipped_xcom_for_model(ti, "model.pkg.my_model") - ti.xcom_push.assert_called_once_with(key="model__pkg__my_model_status", value="skipped") + ti.xcom_push.assert_called_once_with( + key="model__pkg__my_model_status", value={"status": "skipped", "outlet_uris": []} + ) def test_skipped_node_token_updates_exclude(self): producer = self._make_producer() diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index feb4766025..2d5f8855b8 100644 --- a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -162,7 +162,7 @@ def test_first_execution_behaves_as_base_consumer_sensor(mock_startup_events): ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.return_value = "success" + ti.xcom_pull.return_value = {"status": "success", "outlet_uris": []} context = make_context(ti) result = sensor.poke(context) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 849d41df52..ac12f60b71 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,7 +1,10 @@ +import json from datetime import datetime +from unittest.mock import MagicMock, patch import pytest from airflow import DAG +from packaging.version import Version try: # Airflow 3.1 onwards @@ -9,7 +12,12 @@ except ImportError: from airflow.utils.task_group import TaskGroup -from cosmos.dataset import get_dataset_alias_name +from cosmos.dataset import ( + compute_model_outlet_uris, + construct_dataset_uri, + get_dataset_alias_name, + get_dataset_namespace, +) START_DATE = datetime(2024, 4, 16) example_dag = DAG("dag", start_date=START_DATE) @@ -53,3 +61,252 @@ ) def test_get_dataset_alias_name(dag, task_group, task_id, 
result_identifier): assert get_dataset_alias_name(dag, task_group, task_id) == result_identifier + + +# --- Tests for get_dataset_namespace --- + + +class TestGetDatasetNamespace: + def test_postgres_with_profile_mapping(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "postgres" + profile_config.profile_mapping.profile = {"host": "myhost", "port": 5432} + assert get_dataset_namespace(profile_config) == "postgres://myhost:5432" + + def test_bigquery_with_profile_mapping(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "bigquery" + profile_config.profile_mapping.profile = {"project": "my-project"} + assert get_dataset_namespace(profile_config) == "bigquery" + + def test_snowflake_with_profile_mapping(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "snowflake" + profile_config.profile_mapping.profile = {"account": "xy12345.us-east-1"} + # fix_account_name normalizes account locator format (appends cloud suffix) + assert get_dataset_namespace(profile_config) == "snowflake://xy12345.us-east-1.aws" + + def test_unknown_adapter_returns_none(self): + """Unsupported adapters return None so dataset emission is skipped.""" + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "custom_adapter" + profile_config.profile_mapping.profile = {} + assert get_dataset_namespace(profile_config) is None + + def test_incomplete_namespace_returns_none(self): + """Adapters with missing required fields produce scheme-only namespaces; should return None.""" + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "databricks" + profile_config.profile_mapping.profile = {} # no host + 
assert get_dataset_namespace(profile_config) is None + + def test_profiles_yml_filepath(self, tmp_path): + profiles_yml = tmp_path / "profiles.yml" + profiles_yml.write_text(""" +my_profile: + outputs: + dev: + type: postgres + host: dbhost + port: 5433 + user: user + pass: pass + dbname: mydb + schema: public +""") + profile_config = MagicMock() + profile_config.profile_mapping = None + profile_config.profiles_yml_filepath = str(profiles_yml) + profile_config.profile_name = "my_profile" + profile_config.target_name = "dev" + assert get_dataset_namespace(profile_config) == "postgres://dbhost:5433" + + def test_returns_none_on_error(self): + profile_config = MagicMock() + profile_config.profile_mapping = None + profile_config.profiles_yml_filepath = "/nonexistent/path.yml" + assert get_dataset_namespace(profile_config) is None + + def test_returns_none_when_no_profile_info(self): + """When neither profile_mapping nor profiles_yml_filepath is set, returns None.""" + profile_config = MagicMock() + profile_config.profile_mapping = None + profile_config.profiles_yml_filepath = None + assert get_dataset_namespace(profile_config) is None + + def test_spark_with_thrift_method(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "spark" + profile_config.profile_mapping.profile = {"host": "spark-host", "method": "thrift"} + assert get_dataset_namespace(profile_config) == "spark://spark-host:10001" + + def test_spark_with_http_method(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "spark" + profile_config.profile_mapping.profile = {"host": "spark-host", "method": "http"} + assert get_dataset_namespace(profile_config) == "spark://spark-host:443" + + def test_spark_with_explicit_port(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + 
profile_config.profile_mapping.dbt_profile_type = "spark" + profile_config.profile_mapping.profile = {"host": "spark-host", "port": 9999} + assert get_dataset_namespace(profile_config) == "spark://spark-host:9999" + + def test_spark_with_unknown_method(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "spark" + profile_config.profile_mapping.profile = {"host": "spark-host", "method": "session"} + assert get_dataset_namespace(profile_config) == "spark://spark-host:10000" + + def test_glue_with_account_id(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "glue" + profile_config.profile_mapping.profile = {"region": "eu-west-1", "account_id": "123456789"} + assert get_dataset_namespace(profile_config) == "arn:aws:glue:eu-west-1:123456789" + + def test_glue_with_role_arn(self): + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "glue" + profile_config.profile_mapping.profile = { + "region": "us-east-1", + "role_arn": "arn:aws:iam::987654321:role/my-role", + } + assert get_dataset_namespace(profile_config) == "arn:aws:glue:us-east-1:987654321" + + def test_glue_without_account_id_or_role_arn(self): + """When neither account_id nor role_arn is available, returns None.""" + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "glue" + profile_config.profile_mapping.profile = {"region": "us-east-1"} + assert get_dataset_namespace(profile_config) is None + + @patch("cosmos.dataset._resolve_snowflake_account") + def test_snowflake_calls_resolve(self, mock_resolve): + mock_resolve.return_value = "normalized-account" + profile_config = MagicMock() + profile_config.profile_mapping = MagicMock() + profile_config.profile_mapping.dbt_profile_type = "snowflake" 
+ profile_config.profile_mapping.profile = {"account": "raw-account"} + result = get_dataset_namespace(profile_config) + assert result == "snowflake://normalized-account" + mock_resolve.assert_called_once_with("raw-account") + + +class TestResolveSnowflakeAccount: + def test_with_openlineage_available(self): + from cosmos.dataset import _resolve_snowflake_account + + # OL is installed in test env, so fix_account_name should be used + result = _resolve_snowflake_account("xy12345.us-east-1") + assert result == "xy12345.us-east-1.aws" + + def test_without_openlineage(self): + """Falls back to returning the raw account when OL is not installed.""" + from cosmos.dataset import _resolve_snowflake_account + + original_import = __import__ + + def import_raising_for_openlineage(name, globals=None, locals=None, fromlist=(), level=0): + if name == "openlineage.common.provider.snowflake": + raise ImportError("No module named 'openlineage.common.provider.snowflake'") + return original_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=import_raising_for_openlineage): + result = _resolve_snowflake_account("xy12345") + + assert result == "xy12345" + + +# --- Tests for construct_dataset_uri --- + + +class TestConstructDatasetUri: + @patch("cosmos.dataset.AIRFLOW_VERSION", new=Version("2.10.0")) + @patch("cosmos.dataset.settings") + def test_airflow2_default_uri(self, mock_settings): + mock_settings.use_dataset_airflow3_uri_standard = False + uri = construct_dataset_uri("postgres://host:5432", "mydb.public.customers") + assert uri == "postgres://host:5432/mydb.public.customers" + + @patch("cosmos.dataset.AIRFLOW_VERSION", new=Version("2.10.0")) + @patch("cosmos.dataset.settings") + def test_airflow2_with_airflow3_standard(self, mock_settings): + mock_settings.use_dataset_airflow3_uri_standard = True + uri = construct_dataset_uri("postgres://host:5432", "mydb.public.customers") + assert uri == "postgres://host:5432/mydb/public/customers" 
+ + @patch("cosmos.dataset.AIRFLOW_VERSION", new=Version("3.0.0")) + def test_airflow3_uri(self): + uri = construct_dataset_uri("postgres://host:5432", "mydb.public.customers") + assert uri == "postgres://host:5432/mydb/public/customers" + + +# --- Tests for compute_model_outlet_uris --- + + +class TestComputeModelOutletUris: + def test_reads_manifest_and_computes_uris(self, tmp_path): + manifest = { + "nodes": { + "model.jaffle_shop.customers": { + "resource_type": "model", + "database": "postgres", + "schema": "public", + "alias": "customers", + }, + "seed.jaffle_shop.raw_orders": { + "resource_type": "seed", + "database": "postgres", + "schema": "public", + "alias": "raw_orders", + }, + "test.jaffle_shop.not_null": { + "resource_type": "test", + "database": "postgres", + "schema": "public", + "alias": "not_null", + }, + } + } + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text(json.dumps(manifest)) + + result = compute_model_outlet_uris(str(manifest_path), "postgres://host:5432") + # Models and seeds should have URIs, tests should not + assert "model.jaffle_shop.customers" in result + assert "seed.jaffle_shop.raw_orders" in result + assert "test.jaffle_shop.not_null" not in result + assert len(result) == 2 + + def test_missing_manifest_returns_empty(self): + result = compute_model_outlet_uris("/nonexistent/manifest.json", "postgres://host:5432") + assert result == {} + + def test_skips_nodes_with_missing_fields(self, tmp_path): + manifest = { + "nodes": { + "model.pkg.incomplete": { + "resource_type": "model", + "database": "", + "schema": "public", + "alias": "incomplete", + }, + } + } + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text(json.dumps(manifest)) + + result = compute_model_outlet_uris(str(manifest_path), "postgres://host:5432") + assert result == {}