diff --git a/.airflow-registry.yaml b/.airflow-registry.yaml new file mode 100644 index 0000000000..88fe9e7cc9 --- /dev/null +++ b/.airflow-registry.yaml @@ -0,0 +1,219 @@ +# .airflow-registry.yaml — metadata for the Apache Airflow Provider Registry +# Place this file in the root of your provider repository. +# Docs: https://github.com/apache/airflow/blob/main/dev/registry/README.md + +package-name: astronomer-cosmos +display-name: Cosmos (dbt) +description: Run dbt projects as Airflow DAGs and Task Groups. Supports both dbt Core and dbt Fusion. +logo: https://raw.githubusercontent.com/astronomer/astronomer-cosmos/main/docs/_static/cosmos-logo.svg +docs-url: https://astronomer.github.io/astronomer-cosmos +source-url: https://github.com/astronomer/astronomer-cosmos + +operators: + # Execute dbt commands using a local dbt installation, either within the Airflow Python virtual environment or as a standalone dbt binary. + - module: cosmos.operators.local.DbtRunLocalOperator + description: Run `dbt run` using a local dbt installation + - module: cosmos.operators.local.DbtTestLocalOperator + description: Run `dbt test` using a local dbt installation + - module: cosmos.operators.local.DbtSeedLocalOperator + description: Run `dbt seed` using a local dbt installation + - module: cosmos.operators.local.DbtSnapshotLocalOperator + description: Run `dbt snapshot` using a local dbt installation + - module: cosmos.operators.local.DbtCompileLocalOperator + description: Run `dbt compile` using a local dbt installation + - module: cosmos.operators.local.DbtLSLocalOperator + description: Run `dbt ls` using a local dbt installation + - module: cosmos.operators.local.DbtRunOperationLocalOperator + description: Run `dbt run-operation` using a local dbt installation + - module: cosmos.operators.local.DbtBuildLocalOperator + description: Run `dbt build` using a local dbt installation + - module: cosmos.operators.local.DbtCloneLocalOperator + description: Run `dbt clone` using a local dbt 
installation + - module: cosmos.operators.local.DbtSourceLocalOperator + description: Run `dbt source freshness` using a local dbt installation + - module: cosmos.operators.local.DbtDocsLocalOperator + description: Run `dbt docs generate` using a local dbt installation + - module: cosmos.operators.local.DbtDocsS3LocalOperator + description: Run `dbt docs generate` and upload the generated docs to AWS S3 + - module: cosmos.operators.local.DbtDocsAzureStorageLocalOperator + description: Run `dbt docs generate` and upload the generated docs to Azure Blob Storage + - module: cosmos.operators.local.DbtDocsGCSLocalOperator + description: Run `dbt docs generate` and upload the generated docs to Google Cloud Storage (GCS) + + # Execute dbt commands by running each command in its own (local) Docker container. + - module: cosmos.operators.docker.DbtBuildDockerOperator + description: Run `dbt build` inside a Docker container. + - module: cosmos.operators.docker.DbtCloneDockerOperator + description: Run `dbt clone` inside a Docker container. + - module: cosmos.operators.docker.DbtLSDockerOperator + description: Run `dbt ls` inside a Docker container. + - module: cosmos.operators.docker.DbtRunDockerOperator + description: Run `dbt run` inside a Docker container. + - module: cosmos.operators.docker.DbtRunOperationDockerOperator + description: Run `dbt run-operation` inside a Docker container. + - module: cosmos.operators.docker.DbtSeedDockerOperator + description: Run `dbt seed` inside a Docker container. + - module: cosmos.operators.docker.DbtSnapshotDockerOperator + description: Run `dbt snapshot` inside a Docker container. + - module: cosmos.operators.docker.DbtSourceDockerOperator + description: Run `dbt source freshness` inside a Docker container. + - module: cosmos.operators.docker.DbtTestDockerOperator + description: Run `dbt test` inside a Docker container. 
+ + # Execute dbt commands via dedicated Kubernetes pods + - module: cosmos.operators.kubernetes.DbtBuildKubernetesOperator + description: Run `dbt build` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtCloneKubernetesOperator + description: Run `dbt clone` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtLSKubernetesOperator + description: Run `dbt ls` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtRunKubernetesOperator + description: Run `dbt run` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtRunOperationKubernetesOperator + description: Run `dbt run-operation` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtSeedKubernetesOperator + description: Run `dbt seed` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtSnapshotKubernetesOperator + description: Run `dbt snapshot` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtSourceKubernetesOperator + description: Run `dbt source freshness` in a Kubernetes pod. + - module: cosmos.operators.kubernetes.DbtTestKubernetesOperator + description: Run `dbt test` in a Kubernetes pod. + + # Execute dbt commands via isolated Python virtual environments created and managed by Cosmos + - module: cosmos.operators.virtualenv.DbtBuildVirtualenvOperator + description: Run `dbt build` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtCloneVirtualenvOperator + description: Run `dbt clone` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtDocsVirtualenvOperator + description: Run `dbt docs generate` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtLSVirtualenvOperator + description: Run `dbt ls` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtRunVirtualenvOperator + description: Run `dbt run` inside a Cosmos-managed Python virtual environment. 
+ - module: cosmos.operators.virtualenv.DbtRunOperationVirtualenvOperator + description: Run `dbt run-operation` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtSeedVirtualenvOperator + description: Run `dbt seed` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtSnapshotVirtualenvOperator + description: Run `dbt snapshot` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtSourceVirtualenvOperator + description: Run `dbt source freshness` inside a Cosmos-managed Python virtual environment. + - module: cosmos.operators.virtualenv.DbtTestVirtualenvOperator + description: Run `dbt test` inside a Cosmos-managed Python virtual environment. + + # Pre-compile the project using dbt Core and use Airflow `BigQueryInsertJobOperator` to asynchronously run the transformations. + - module: cosmos.operators._asynchronous.SetupAsyncOperator + description: Monkey-patch `dbt run` command to pre-generate the dbt models SQL and export them to XCom + - module: cosmos.operators.airflow_async.DbtRunAirflowAsyncOperator + description: Retrieve the SQL from the setup task and execute each model natively to Airflow via a subclass of `BigQueryInsertJobOperator` + - module: cosmos.operators._asynchronous.TeardownAsyncOperator + description: Clear all the SQL files related to this DAG run from corresponding XComs + + # Run dbt commands as containers via the Amazon Elastic Container Service (AWS ECS) + - module: cosmos.operators.aws_ecs.DbtBuildAwsEcsOperator + description: Run `dbt build` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtLSAwsEcsOperator + description: Run `dbt ls` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtRunAwsEcsOperator + description: Run `dbt run` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtRunOperationAwsEcsOperator + description: Run `dbt run-operation` as a task in AWS ECS. 
+ - module: cosmos.operators.aws_ecs.DbtSeedAwsEcsOperator + description: Run `dbt seed` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtSnapshotAwsEcsOperator + description: Run `dbt snapshot` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtSourceAwsEcsOperator + description: Run `dbt source freshness` as a task in AWS ECS. + - module: cosmos.operators.aws_ecs.DbtTestAwsEcsOperator + description: Run `dbt test` as a task in AWS ECS. + + # Run dbt commands as Kubernetes pods via the Amazon Elastic Kubernetes Service (AWS EKS). + - module: cosmos.operators.aws_eks.DbtBuildAwsEksOperator + description: Run `dbt build` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtCloneAwsEksOperator + description: Run `dbt clone` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtLSAwsEksOperator + description: Run `dbt ls` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtRunAwsEksOperator + description: Run `dbt run` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtRunOperationAwsEksOperator + description: Run `dbt run-operation` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtSeedAwsEksOperator + description: Run `dbt seed` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtSnapshotAwsEksOperator + description: Run `dbt snapshot` in a pod on AWS EKS. + - module: cosmos.operators.aws_eks.DbtTestAwsEksOperator + description: Run `dbt test` in a pod on AWS EKS. + + # Run dbt commands as remote containers via the Azure Container Instance service. + - module: cosmos.operators.azure_container_instance.DbtBuildAzureContainerInstanceOperator + description: Run `dbt build` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtCloneAzureContainerInstanceOperator + description: Run `dbt clone` in an Azure Container Instance. 
+ - module: cosmos.operators.azure_container_instance.DbtLSAzureContainerInstanceOperator + description: Run `dbt ls` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtRunAzureContainerInstanceOperator + description: Run `dbt run` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtRunOperationAzureContainerInstanceOperator + description: Run `dbt run-operation` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtSeedAzureContainerInstanceOperator + description: Run `dbt seed` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtSnapshotAzureContainerInstanceOperator + description: Run `dbt snapshot` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtSourceAzureContainerInstanceOperator + description: Run `dbt source freshness` in an Azure Container Instance. + - module: cosmos.operators.azure_container_instance.DbtTestAzureContainerInstanceOperator + description: Run `dbt test` in an Azure Container Instance. + + # Run dbt commands as Google Cloud Platform (GCP) Cloud Run Jobs + - module: cosmos.operators.gcp_cloud_run_job.DbtBuildGcpCloudRunJobOperator + description: Run `dbt build` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtCloneGcpCloudRunJobOperator + description: Run `dbt clone` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtLSGcpCloudRunJobOperator + description: Run `dbt ls` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtRunGcpCloudRunJobOperator + description: Run `dbt run` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtRunOperationGcpCloudRunJobOperator + description: Run `dbt run-operation` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtSeedGcpCloudRunJobOperator + description: Run `dbt seed` as a Google Cloud Run Job. 
+ - module: cosmos.operators.gcp_cloud_run_job.DbtSnapshotGcpCloudRunJobOperator + description: Run `dbt snapshot` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtSourceGcpCloudRunJobOperator + description: Run `dbt source freshness` as a Google Cloud Run Job. + - module: cosmos.operators.gcp_cloud_run_job.DbtTestGcpCloudRunJobOperator + description: Run `dbt test` as a Google Cloud Run Job. + + # Watcher pattern (local) + # Run a single local `dbt build` or `dbt run` command per pipeline, via a producer task, and have Airflow deferrable sensors (consumers) tracking progress for individual dbt nodes. + - module: cosmos.operators.watcher.DbtProducerWatcherOperator + description: Run a single dbt command (build or run) for the entire dbt project and produce one XCom containing the status of each dbt node (model, seed, snapshot). + - module: cosmos.operators.watcher.DbtRunWatcherOperator + description: Monitor a dbt model using Airflow deferrable XCom sensors. On retry, executes `dbt run` for the specific model via `DbtRunLocalOperator`. + - module: cosmos.operators.watcher.DbtSeedWatcherOperator + description: Monitor a dbt seed using Airflow deferrable XCom sensors. On retry, executes `dbt seed` for the specific seed via `DbtSeedLocalOperator`. + - module: cosmos.operators.watcher.DbtSnapshotWatcherOperator + description: Monitor a dbt snapshot using Airflow deferrable XCom sensors. On retry, executes `dbt snapshot` for the specific snapshot via `DbtSnapshotLocalOperator`. + - module: cosmos.operators.watcher.DbtTestWatcherOperator + description: Runs `dbt test` using the local execution pattern, or acts as an `EmptyOperator` for `TestBehavior.AFTER_EACH`. + + # Watcher pattern (Kubernetes) + # Run a single `dbt build` or `dbt run` command per pipeline (producer), via a Kubernetes pod, and have Airflow deferrable sensors (consumers) tracking progress for individual dbt nodes. 
+ - module: cosmos.operators.watcher_kubernetes.DbtProducerWatcherKubernetesOperator + description: Run a single dbt command (build or run) via a Kubernetes Pod Operator for the entire dbt project and produce one XCom containing the status of each dbt node (model, seed, snapshot). + - module: cosmos.operators.watcher_kubernetes.DbtRunWatcherKubernetesOperator + description: Monitor a dbt model using Airflow deferrable XCom sensors. On retry, executes `dbt run` for the specific model via `DbtRunKubernetesOperator`. + - module: cosmos.operators.watcher_kubernetes.DbtSeedWatcherKubernetesOperator + description: Monitor a dbt seed using Airflow deferrable XCom sensors. On retry, executes `dbt seed` for the specific seed via `DbtSeedKubernetesOperator`. + - module: cosmos.operators.watcher_kubernetes.DbtSnapshotWatcherKubernetesOperator + description: Monitor a dbt snapshot using Airflow deferrable XCom sensors. On retry, executes `dbt snapshot` for the specific snapshot via `DbtSnapshotKubernetesOperator`. + - module: cosmos.operators.watcher_kubernetes.DbtTestWatcherKubernetesOperator + description: Runs `dbt test` using the Kubernetes execution pattern, or acts as an `EmptyOperator` for `TestBehavior.AFTER_EACH`. 
+ +other: + - module: cosmos.airflow.dag.DbtDag + description: Render a complete dbt project as an Airflow DAG + - module: cosmos.airflow.task_group.DbtTaskGroup + description: Render a dbt project as an Airflow Task Group diff --git a/.codespell-ignore-words b/.codespell-ignore-words new file mode 100644 index 0000000000..a57e6f064d --- /dev/null +++ b/.codespell-ignore-words @@ -0,0 +1 @@ +rin diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7df1efa414..b29ff67eff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,6 +65,7 @@ repos: - --exclude-file=tests/sample/manifest_model_version.json - --skip=**/manifest.json,**.min.js,**/manifest_source.json - -L connexion,aci + - --ignore-words=.codespell-ignore-words - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 081b0096dc..d185a28d02 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,37 @@ Changelog ========= +1.13.1 (2026-02-25) +------------------- + +Enhancements + +* Change Snowflake profile mappings to default to four threads by @tatiana in #2374 +* Refactor to avoid potential future ``UnboundLocalError`` for ``producer_task`` in ``calculate_tasks_map`` by @rin in #2309 + +Bug Fixes + +* Fix graph selector when using + selector with ``dbt-loom`` by @award1230 in #2389 +* Populate ``compiled_sql`` for ``InvocationMode.SUBPROCESS`` in ``ExecutionMode.WATCHER`` by @pankajkoti in #2319 +* Preserve ``extra_context`` for watcher consumer task instances by @pankajkoti in #2381 +* Fix watcher: respect ``deferrable=False`` from ``operator_args`` on consumer sensor by @pankajkoti in #2384 +* Error handle invalid YAML with ``LoadMode.DBT_MANIFEST`` and ``RenderConfig.selector`` by @jonbillings in #2316 +* Fix selecting model when it has the same name as folder by @pankajastro in #2328 +* Handle Param Validation errors by @tatiana in #2358 +* Fix cache swap issue by @jonbillings in #2332 +* Fix leaked 
semaphore warnings in Airflow 3 by resetting dbt adapters by @pankajkoti in #2335 + +Docs + +* Document ``ExecutionMode.KUBERNETES`` limitations by @tatiana in #2326 + +Others + +* Add .airflow-registry.yaml for Airflow Provider Registry by @kaxil in #2387 +* Improve test coverage for PR #2307 by @tatiana in #2308 +* Address feedback from code review #2389 by @evanvolgas in #2394 + + 1.13.0 (2026-01-30) --------------------- diff --git a/README.rst b/README.rst index 3f28a81cf1..0fcc891338 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,9 @@ +.. License badge +.. |license| image:: https://img.shields.io/:license-Apache%202-blue.svg + :target: https://www.apache.org/licenses/LICENSE-2.0.txt + :alt: License + +.. PyPI badges .. |fury| image:: https://badge.fury.io/py/astronomer-cosmos.svg :target: https://badge.fury.io/py/astronomer-cosmos diff --git a/cosmos/__init__.py b/cosmos/__init__.py index 6dac0e0f42..6b22a75da5 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -10,7 +10,7 @@ from cosmos import settings -__version__ = "1.13.0" +__version__ = "1.13.1" if not settings.enable_memory_optimised_imports: diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 156fea563f..d149c0855f 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -288,7 +288,7 @@ def create_dbt_resource_to_class(test_behavior: TestBehavior) -> dict[str, str]: return dbt_resource_to_class -def create_task_metadata( +def create_task_metadata( # noqa: C901 node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], @@ -340,7 +340,10 @@ def create_task_metadata( models_select_key = "models" if settings.pre_dbt_fusion else "select" if render_config.test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES: - args[models_select_key] = f"{node.resource_name}" + if node.fqn and len(node.fqn) > 0: + args[models_select_key] = f"fqn:{'.'.join(node.fqn)}" + else: + args[models_select_key] = f"{node.resource_name}" if 
test_indirect_selection != TestIndirectSelection.EAGER: args["indirect_selection"] = test_indirect_selection.value args["on_warning_callback"] = on_warning_callback @@ -390,7 +393,10 @@ def create_task_metadata( args = {} return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator", arguments=args) else: # DbtResourceType.MODEL, DbtResourceType.SEED and DbtResourceType.SNAPSHOT - args[models_select_key] = node.resource_name + if node.fqn and len(node.fqn) > 0: + args[models_select_key] = f"fqn:{'.'.join(node.fqn)}" + else: + args[models_select_key] = node.resource_name task_id, args = _get_task_id_and_args( node=node, args=args, @@ -853,6 +859,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro """ tasks_map: dict[str, TaskGroup | BaseOperator] = {} task_or_group: TaskGroup | BaseOperator | None + producer_task: BaseOperator | None = None # Identify test nodes that should be run detached from the associated dbt resource nodes because they # have multiple parents @@ -947,8 +954,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro create_airflow_task_dependencies(nodes, tasks_map) - if execution_mode in (ExecutionMode.WATCHER, ExecutionMode.WATCHER_KUBERNETES): - setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} + if producer_task: _add_watcher_dependencies( dag=dag, producer_airflow_task=producer_task, diff --git a/cosmos/converter.py b/cosmos/converter.py index a3136eab17..902d1bdfd0 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -12,6 +12,7 @@ from typing import Any from warnings import warn +from airflow.exceptions import ParamValidationError from airflow.models.dag import DAG from airflow.models.param import Param @@ -448,7 +449,14 @@ def _store_cosmos_telemetry_metadata_on_dag( # noqa: C901 # Store metadata in dag.params which is preserved during serialization # Using a key that's unlikely to conflict with user params 
compressed_metadata = _compress_telemetry_metadata(metadata) - dag.params["__cosmos_telemetry_metadata__"] = Param(default=compressed_metadata, const=compressed_metadata) - logger.debug( - f"Stored compressed Cosmos telemetry metadata in DAG {dag.dag_id} params (original size: {len(str(metadata))} bytes, compressed: {len(compressed_metadata)} bytes)" - ) + stored_metadata = False + try: + dag.params["__cosmos_telemetry_metadata__"] = Param(default=compressed_metadata, const=compressed_metadata) + stored_metadata = True + except ParamValidationError as e: + logger.warning(f"Failed to store compressed Cosmos telemetry metadata in DAG {dag.dag_id} params: {e}") + + if stored_metadata: + logger.debug( + f"Stored compressed Cosmos telemetry metadata in DAG {dag.dag_id} params (original size: {len(str(metadata))} bytes, compressed: {len(compressed_metadata)} bytes)" + ) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 1696c3991b..b513ca39e0 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -95,6 +95,7 @@ class DbtNode: has_test: bool = False has_non_detached_test: bool = False downstream: list[str] = field(default_factory=lambda: []) + fqn: list[str] | None = None @property def meta(self) -> dict[str, Any]: @@ -342,6 +343,7 @@ def parse_dbt_ls_output(project_path: Path | None, ls_stdout: str) -> dict[str, if DbtResourceType(node_dict["resource_type"]) == DbtResourceType.SOURCE else False ), + fqn=node_dict.get("fqn"), ) except (KeyError, TypeError): logger.info("Could not parse following the dbt ls line even though it was a valid JSON `%s`", line) @@ -568,6 +570,10 @@ def get_dbt_ls_cache(self) -> dict[str, str]: if dbt_ls_compressed: encoded_data = base64.b64decode(dbt_ls_compressed.encode()) cache_dict["dbt_ls"] = zlib.decompress(encoded_data).decode() + else: + # Missing 'dbt_ls_compressed' key indicates wrong cache type or corrupted cache + # Return empty dict to trigger cache miss and force fresh dbt ls run + cache_dict = {} return cache_dict 
@@ -641,6 +647,7 @@ def run_dbt_ls( "tags", "config", "freshness", + "fqn", ] else: ls_command = [ @@ -996,8 +1003,11 @@ def get_yaml_selectors_cache(self) -> dict[str, Any]: parsed_selectors = json.loads(zlib.decompress(encoded_parsed).decode()) cache_dict["yaml_selectors"] = YamlSelectors(raw_selectors, parsed_selectors) - - return cache_dict + else: + # Missing selector keys indicates wrong cache type or corrupted cache + # Return empty dict to trigger cache miss and force fresh selector parsing + cache_dict = {} + return cache_dict def save_yaml_selectors_cache(self, yaml_selectors: YamlSelectors) -> None: """ @@ -1154,6 +1164,7 @@ def load_from_dbt_manifest(self) -> None: if DbtResourceType(node_dict["resource_type"]) == DbtResourceType.SOURCE else False ), + fqn=node_dict.get("fqn"), ) nodes[node.unique_id] = node diff --git a/cosmos/dbt/runner.py b/cosmos/dbt/runner.py index 345701ff31..be4f8f9e55 100644 --- a/cosmos/dbt/runner.py +++ b/cosmos/dbt/runner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import gc import sys from collections.abc import Callable from functools import lru_cache @@ -61,6 +62,28 @@ def get_runner(callbacks: list[Callable] | None = None) -> dbtRunner: # type: i return _get_cached_dbt_runner() +def _cleanup_dbt_adapters() -> None: + """ + Reset dbt adapters to release semaphores. + + dbt adapters maintain internal state that holds onto + semaphores. Resetting the adapters after each dbt command combined with + garbage collection prevents "leaked semaphore objects" warnings. 
+ + See: https://github.com/astronomer/astronomer-cosmos/issues/2334 + """ + try: + from dbt.adapters.factory import reset_adapters + + reset_adapters() + except ImportError: + pass + except (RuntimeError, KeyError, AttributeError): + logger.debug("Error resetting dbt adapters", exc_info=True) + + gc.collect() + + def run_command( command: list[str], env: dict[str, str], cwd: str, callbacks: list[Callable] | None = None, **kwargs: Any # type: ignore[type-arg] ) -> dbtRunnerResult: @@ -74,7 +97,13 @@ def run_command( with change_working_directory(cwd), environ(env): logger.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd) runner = get_runner(callbacks=callbacks) - result = runner.invoke(cli_args) + try: + result = runner.invoke(cli_args) + finally: + # Reset dbt adapters to release semaphores (run on all exit paths) + # See: https://github.com/astronomer/astronomer-cosmos/issues/2334 + _cleanup_dbt_adapters() + return result diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index 2f01ee4e78..f61f8fa1fe 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -168,7 +168,11 @@ def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, select new_generation: set[str] = set() for node_id in previous_generation: if node_id not in processed_nodes: - new_generation.update(set(nodes[node_id].depends_on)) + # When using dbt-loom for cross-project references, external nodes are filtered out + # during manifest loading but local nodes may still reference them in depends_on. + # Skip missing nodes to gracefully stop traversal at project boundaries. 
+ if node_id in nodes: + new_generation.update(set(nodes[node_id].depends_on)) processed_nodes.add(node_id) selected_nodes.update(new_generation) previous_generation = new_generation @@ -507,6 +511,11 @@ def select_nodes_ids_by_intersection(self) -> set[str]: if self.config.graph_selectors: graph_selected_nodes = self.select_by_graph_operator() for node_id in graph_selected_nodes: + # When using dbt-loom for cross-project references, external nodes are filtered out + # during manifest loading but may be collected during graph traversal via depends_on. + # Skip these external node IDs that don't exist in the nodes dict. + if node_id not in self.nodes: + continue node = self.nodes[node_id] # Since the method below changes the tags of test nodes, it can lead to incorrect # results during the application of graph selectors. Therefore, it is being run within @@ -774,11 +783,22 @@ def get_parsed(self, selector_name: str, default: Any = None) -> dict[str, Any] :param selector_name: str - Name of the selector to retrieve :param default: Any - Default value to return if selector is not found :return: dict[str, Any] | Any - The parsed selector definition or the default value + :raises CosmosValueError: If the selector has parsing errors """ - return self.parsed.get(selector_name, default) + parsed_selector = self.parsed.get(selector_name, default) + + if parsed_selector and isinstance(parsed_selector, dict): + errors = parsed_selector.get("errors", []) + + if errors: + error_list = "\n - ".join(errors) + error_message = f"Error(s) parsing selector '{selector_name}':\n - {error_list}" + raise CosmosValueError(error_message) + + return parsed_selector @staticmethod - def _get_list_dicts(dct: dict[str, Any], key: str) -> list[dict[str, Any]]: + def _get_list_dicts(dct: dict[str, Any], key: str) -> tuple[list[dict[str, Any]], list[str]]: """ Extract and validate a list of dictionaries from a given key in a dictionary. 
@@ -787,63 +807,67 @@ def _get_list_dicts(dct: dict[str, Any], key: str) -> list[dict[str, Any]]: :param dct: dict[str, Any] - The dictionary to extract from :param key: str - The key to look up in the dictionary - :return: list[dict[str, Any]] - List of dictionaries extracted from the specified key - :raises CosmosValueError: If the key is missing, the value is not a list, or contains invalid types + :return: tuple[list[dict[str, Any]], list[str]] - Tuple of (list of dictionaries, list of error messages) """ - result: list[dict[str, Any]] = [] - if key not in dct: - raise CosmosValueError(f"Expected to find key '{key}' in dict, only found {list(dct)}") values = dct[key] if not isinstance(values, list): - raise CosmosValueError(f"Invalid value for key '{key}'. Expected a list.") + return ([], [f"Invalid value for key '{key}'. Expected a list."]) + + result: list[dict[str, Any]] = [] + errors: list[str] = [] + for value in values: if isinstance(value, dict): - for value_key in value: - if not isinstance(value_key, str): - raise CosmosValueError( - f'Expected all keys to "{key}" dict to be strings, ' - f'but "{value_key}" is a "{type(value_key)}"' - ) - result.append(value) + if all(isinstance(value_key, str) for value_key in value): + result.append(value) + else: + for value_key in value: + if not isinstance(value_key, str): + errors.append( + f'Expected all keys to "{key}" dict to be strings, ' + f'but "{value_key}" is a "{type(value_key)}"' + ) else: - raise CosmosValueError( - f'Invalid value type {type(value)} in key "{key}", expected dict (value: {value}).' 
- ) + errors.append(f'Invalid value type {type(value)} in key "{key}", expected dict (value: {value}).') - return result + return (result, errors) @staticmethod def _parse_selection_graph_operators( - selector: str, parents: bool, children: bool, parents_depth: int, children_depth: int, childrens_parents: bool - ) -> str: + selector: str | None, + parents: bool, + children: bool, + parents_depth: int, + children_depth: int, + childrens_parents: bool, + ) -> tuple[str | None, str | None]: """ Apply graph operator syntax to a selector string. Transforms a selector by adding graph operators (+, @) based on the specified parameters. These operators control traversal of the DAG to include parents, children, or both. - :param selector: str - The base selector string to modify + :param selector: str | None - The base selector string to modify (None if there was a parse error) :param parents: bool - Whether to include parent nodes :param children: bool - Whether to include child nodes :param parents_depth: int - Number of parent generations to include (0 for all) :param children_depth: int - Number of child generations to include (0 for all) :param childrens_parents: bool - Whether to use the @ operator (all ancestors of children) - :return: str - The selector string with graph operators applied - :raises CosmosValueError: If childrens_parents is combined with depth parameters + :return: tuple[str | None, str | None] - Tuple of (selector string with graph operators or None, error message or None) Examples: - - selector="my_model", parents=True, children=False -> "+my_model" - - selector="my_model", parents=True, children=True, parents_depth=2 -> "2+my_model+" - - selector="my_model", childrens_parents=True -> "@my_model" + - selector="my_model", parents=True, children=False -> ("+my_model", None) + - selector="my_model", parents=True, children=True, parents_depth=2 -> ("2+my_model+", None) + - selector="my_model", childrens_parents=True -> ("@my_model", None) """ if not 
selector: - return selector + return (selector, None) if childrens_parents and (parents_depth or children_depth): - raise CosmosValueError("childrens_parents cannot be combined with parents_depth or children_depth.") + return (selector, "childrens_parents cannot be combined with parents_depth or children_depth.") if childrens_parents: - return f"{AT_SELECTOR}{selector}" + return (f"{AT_SELECTOR}{selector}", None) if parents: prefix = f"{parents_depth}{PLUS_SELECTOR}" if parents_depth > 0 else PLUS_SELECTOR @@ -853,10 +877,10 @@ def _parse_selection_graph_operators( suffix = f"{PLUS_SELECTOR}{children_depth}" if children_depth > 0 else PLUS_SELECTOR selector = f"{selector}{suffix}" - return selector + return (selector, None) @staticmethod - def _parse_selection_from_cosmos_spec(method: str, value: str) -> str: + def _parse_selection_from_cosmos_spec(method: str, value: str) -> tuple[str | None, str | None]: """ Convert a dbt YAML selector method and value into Cosmos selector syntax. @@ -865,17 +889,16 @@ def _parse_selection_from_cosmos_spec(method: str, value: str) -> str: :param method: str - The selector method (e.g., "tag", "path", "config.materialized") :param value: str - The value for the selector method - :return: str - A Cosmos-formatted selector string - :raises CosmosValueError: If the method is not supported + :return: tuple[str | None, str | None] - Tuple of (selector string or None, error message or None) Examples: - - method="tag", value="nightly" -> "tag:nightly" - - method="path", value="models/" -> "path:models/" - - method="fqn", value="*" -> "" - - method="config.materialized", value="view" -> "config.materialized:view" + - method="tag", value="nightly" -> ("tag:nightly", None) + - method="path", value="models/" -> ("path:models/", None) + - method="fqn", value="*" -> ("", None) + - method="config.materialized", value="view" -> ("config.materialized:view", None) """ if method == "fqn": - return "" if value == "*" else value + return ("" if value 
== "*" else value, None) method_mappings = { TAG_SELECTOR[:-1]: TAG_SELECTOR, @@ -888,15 +911,15 @@ def _parse_selection_from_cosmos_spec(method: str, value: str) -> str: for method_prefix, selector_prefix in method_mappings.items(): if method == method_prefix: - return f"{selector_prefix}{value}" + return (f"{selector_prefix}{value}", None) if any(method.startswith(f"{CONFIG_SELECTOR}{config}") for config in SUPPORTED_CONFIG): - return f"{method}:{value}" + return (f"{method}:{value}", None) - raise CosmosValueError(f"Unsupported selector method: '{method}'") + return (None, f"Unsupported selector method: '{method}'") @classmethod - def _parse_selector(cls, dct: dict[str, Any]) -> str: + def _parse_selector(cls, dct: dict[str, Any]) -> tuple[str | None, list[str]]: """ Parse a single selector definition dictionary into a Cosmos selector string. @@ -905,8 +928,7 @@ def _parse_selector(cls, dct: dict[str, Any]) -> str: :param dct: dict[str, Any] - Dictionary containing selector definition with keys like 'method', 'value', 'parents', 'children', 'parents_depth', 'children_depth', 'childrens_parents' - :return: str - A Cosmos-formatted selector string - :raises CosmosValueError: If depth values are not integers + :return: tuple[str | None, list[str]] - Tuple of (selector string or None, list of error messages) Example Input: { @@ -917,7 +939,7 @@ def _parse_selector(cls, dct: dict[str, Any]) -> str: } Example Output: - "tag:nightly+2" + ("tag:nightly+2", []) """ method = dct["method"] value = dct["value"] @@ -927,23 +949,33 @@ def _parse_selector(cls, dct: dict[str, Any]) -> str: children_depth = dct.get("children_depth", 0) childrens_parents = dct.get("childrens_parents", False) + errors: list[str] = [] + if not isinstance(parents_depth, int): - raise CosmosValueError(f"parents_depth must be an integer, got {type(parents_depth)}: {parents_depth}") + errors.append(f"parents_depth must be an integer, got {type(parents_depth)}: {parents_depth}") if not 
isinstance(children_depth, int): - raise CosmosValueError(f"children_depth must be an integer, got {type(children_depth)}: {children_depth}") + errors.append(f"children_depth must be an integer, got {type(children_depth)}: {children_depth}") - selector = cls._parse_selection_from_cosmos_spec(method, value) + if errors: + return (None, errors) - selector = cls._parse_selection_graph_operators( + selector, error = cls._parse_selection_from_cosmos_spec(method, value) + if error: + errors.append(error) + return (None, errors) + + selector, error = cls._parse_selection_graph_operators( selector, parents, children, parents_depth, children_depth, childrens_parents ) + if error: + errors.append(error) - return selector + return (selector, errors) @classmethod def _parse_exclusions( cls, definition: dict[str, Any], cache: dict[str, tuple[list[str], list[str]]] | None = None - ) -> list[str]: + ) -> tuple[list[str], list[str]]: """ Parse exclusion definitions from a selector definition. @@ -952,25 +984,28 @@ def _parse_exclusions( :param definition: dict[str, Any] - Dictionary containing the selector definition with an 'exclude' key :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution - :return: list[str] - List of exclusion selector strings + :return: tuple[list[str], list[str]] - Tuple of (exclusion selector strings, error messages) """ cache = cache or {} - - exclusions = cls._get_list_dicts(definition, "exclude") exclude: list[str] = [] + errors: list[str] = [] + + exclusions, definition_errors = cls._get_list_dicts(definition, "exclude") + errors.extend(definition_errors) for exclusion in exclusions: - excl, _ = cls._parse_from_definition(exclusion, cache=cache) + excl, _, parse_errors = cls._parse_from_definition(exclusion, cache=cache) exclude.extend(excl) + errors.extend(parse_errors) - return exclude + return (exclude, errors) @classmethod def _parse_include_exclude_subdefs( cls, definitions: 
list[dict[str, Any]], cache: dict[str, tuple[list[str], list[str]]], - ) -> tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str], list[str]]: """ Parse a list of selector subdefinitions into include and exclude lists. @@ -979,35 +1014,38 @@ def _parse_include_exclude_subdefs( :param definitions: list[dict[str, Any]] - List of selector definition dictionaries :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution - :return: tuple[list[str], list[str]] - Tuple of (include_list, exclude_list) containing selector strings - :raises CosmosValueError: If multiple exclude clauses are found at the same level + :return: tuple[list[str], list[str], list[str]] - Tuple of (include_list, exclude_list, error messages) """ include: list[str] = [] exclude: list[str] = [] + errors: list[str] = [] for definition in definitions: if isinstance(definition, dict) and "exclude" in definition: # Do not allow multiple exclude: defs at the same level if exclude: yaml_sel_cfg = yaml.dump(definition) - raise CosmosValueError( + errors.append( f"You cannot provide multiple exclude arguments to the " f"same selector set operator:\n{yaml_sel_cfg}" ) - definition_exclude = cls._parse_exclusions(definition, cache=cache) - exclude.extend(definition_exclude) + else: + definition_exclude, excl_errors = cls._parse_exclusions(definition, cache=cache) + exclude.extend(definition_exclude) + errors.extend(excl_errors) else: - definition_include, _ = cls._parse_from_definition(definition, cache=cache) + definition_include, _, parse_errors = cls._parse_from_definition(definition, cache=cache) include.extend(definition_include) + errors.extend(parse_errors) - return (include, exclude) + return (include, exclude, errors) @classmethod def _parse_union_definition( cls, definition: dict[str, Any], cache: dict[str, tuple[list[str], list[str]]] | None = None, - ) -> tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str], list[str]]: """ 
Parse a union selector definition. @@ -1016,7 +1054,7 @@ def _parse_union_definition( :param definition: dict[str, Any] - Dictionary containing a 'union' key with a list of selector definitions :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution - :return: tuple[list[str], list[str]] - Tuple of (include_list, exclude_list) containing selector strings + :return: tuple[list[str], list[str], list[str]] - Tuple of (include_list, exclude_list, error messages) Example Input: { @@ -1027,19 +1065,25 @@ def _parse_union_definition( } Example Output: - (["tag:nightly", "path:models/staging"], []) + (["tag:nightly", "path:models/staging"], [], []) """ cache = cache or {} + errors: list[str] = [] - union_def_parts = cls._get_list_dicts(definition, "union") - return cls._parse_include_exclude_subdefs(union_def_parts, cache=cache) + union_def_parts, definition_errors = cls._get_list_dicts(definition, "union") + errors.extend(definition_errors) + + include, exclude, parse_errors = cls._parse_include_exclude_subdefs(union_def_parts, cache=cache) + errors.extend(parse_errors) + + return (include, exclude, errors) @classmethod def _parse_intersection_definition( cls, definition: dict[str, Any], cache: dict[str, tuple[list[str], list[str]]], - ) -> tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str], list[str]]: """ Parse an intersection selector definition. 
@@ -1049,7 +1093,7 @@ def _parse_intersection_definition( :param definition: dict[str, Any] - Dictionary containing an 'intersection' key with a list of selector definitions :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution - :return: tuple[list[str], list[str]] - Tuple of (include_list, exclude_list) with intersected selectors joined by commas + :return: tuple[list[str], list[str], list[str]] - Tuple of (include_list, exclude_list, error messages) Example Input: { @@ -1060,21 +1104,26 @@ def _parse_intersection_definition( } Example Output: - (["tag:nightly,config.materialized:view"], []) + (["tag:nightly,config.materialized:view"], [], []) """ - intersection_def_parts = cls._get_list_dicts(definition, "intersection") - include, exclude = cls._parse_include_exclude_subdefs(intersection_def_parts, cache=cache) + errors: list[str] = [] - intersection = [",".join(include)] + intersection_def_parts, definition_errors = cls._get_list_dicts(definition, "intersection") + errors.extend(definition_errors) - return (intersection, exclude) + include, exclude, parse_errors = cls._parse_include_exclude_subdefs(intersection_def_parts, cache=cache) + errors.extend(parse_errors) + + intersection = [",".join(include)] if include else [] + + return (intersection, exclude, errors) @classmethod def _parse_method_definition( cls, definition: dict[str, Any], cache: dict[str, tuple[list[str], list[str]]], - ) -> tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str], list[str]]: """ Parse a method-based selector definition. 
@@ -1083,9 +1132,7 @@ def _parse_method_definition( :param definition: dict[str, Any] - Dictionary containing 'method' and 'value' keys, and optionally 'exclude' :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution - :return: tuple[list[str], list[str]] - Tuple of (include_list, exclude_list) containing selector strings - :raises CosmosValueError: If 'method' or 'value' keys are missing, or if a referenced - selector doesn't exist + :return: tuple[list[str], list[str], list[str]] - Tuple of (include_list, exclude_list, error messages) Example Input (method-based): { @@ -1095,7 +1142,7 @@ def _parse_method_definition( } Example Output: - (["tag:nightly"], ["path:models/archived"]) + (["tag:nightly"], ["path:models/archived"], []) Example Input (selector reference): { @@ -1104,28 +1151,33 @@ def _parse_method_definition( } Example Output: - Returns the cached tuple of (include_list, exclude_list) from the referenced selector + Returns the cached tuple of (include_list, exclude_list, []) from the referenced selector """ exclude: list[str] = [] + errors: list[str] = [] if definition.get("method") == "selector": sel_def = definition.get("value") if sel_def not in cache: - raise CosmosValueError(f"Existing selector definition for {sel_def} not found.") - return cache[definition["value"]] + errors.append(f"Existing selector definition for {sel_def} not found.") + return ([], [], errors) + return (*cache[definition["value"]], []) elif "method" in definition and "value" in definition: dct = definition if "exclude" in definition: - exclude = cls._parse_exclusions(definition, cache=cache) + exclude, excl_errors = cls._parse_exclusions(definition, cache=cache) + errors.extend(excl_errors) dct = {k: v for k, v in dct.items() if k != "exclude"} else: - raise CosmosValueError(f"Expected 'method' and 'value' keys, but got {list(definition)}") + errors.append(f"Expected 'method' and 'value' keys, but got 
{list(definition)}") + return ([], [], errors) - selector = cls._parse_selector(dct) + selector, parse_errors = cls._parse_selector(dct) + errors.extend(parse_errors) include = [selector] if selector else [] - return (include, exclude) + return (include, exclude, errors) @classmethod def _parse_from_definition( @@ -1133,7 +1185,7 @@ def _parse_from_definition( definition: dict[str, Any], cache: dict[str, tuple[list[str], list[str]]], rootlevel: bool = False, - ) -> tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str], list[str]]: """ Recursively parse a selector definition into include and exclude lists. @@ -1144,14 +1196,15 @@ def _parse_from_definition( :param definition: dict[str, Any] - The selector definition dictionary to parse :param cache: dict[str, tuple[list[str], list[str]]] - Cache of previously parsed selectors for reference resolution :param rootlevel: bool - Whether this is a root-level selector (affects validation rules) - :return: tuple[list[str], list[str]] - Tuple of (include_list, exclude_list) containing selector strings - :raises CosmosValueError: If the definition structure is invalid or uses unsupported formats + :return: tuple[list[str], list[str], list[str]] - Tuple of (include_list, exclude_list, error messages) Note: Root-level selectors can only have a single 'union' or 'intersection' key. CLI-style string definitions and key-value pairs are not supported - use method-based definitions instead. """ + errors: list[str] = [] + if ( isinstance(definition, dict) and ("union" in definition or "intersection" in definition) @@ -1159,29 +1212,34 @@ def _parse_from_definition( and len(definition) > 1 ): keys = ",".join(definition.keys()) - raise CosmosValueError( + errors.append( f"Only a single 'union' or 'intersection' key is allowed " f"in a root level selector definition; found {keys}." 
) + return ([], [], errors) if "union" in definition: - select, exclude = cls._parse_union_definition(definition, cache=cache) + select, exclude, parse_errors = cls._parse_union_definition(definition, cache=cache) + errors.extend(parse_errors) elif "intersection" in definition: - select, exclude = cls._parse_intersection_definition(definition, cache=cache) + select, exclude, parse_errors = cls._parse_intersection_definition(definition, cache=cache) + errors.extend(parse_errors) elif "method" in definition: - select, exclude = cls._parse_method_definition(definition, cache=cache) + select, exclude, parse_errors = cls._parse_method_definition(definition, cache=cache) + errors.extend(parse_errors) else: # Rejects CLI-style string definitions (e.g., definition: "tag:my_tag") # These should be structured as method-value definitions instead # # Rejects key-value definitions (e.g., tag: my_tag, path: models/, etc.) # These should be structured as method-value definitions instead - raise CosmosValueError( + errors.append( f"Expected to find union, intersection, or method definition, instead " f"found {type(definition)}: {definition}" ) + return ([], [], errors) - return (select, exclude) + return (select, exclude, errors) @classmethod def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: @@ -1213,6 +1271,8 @@ def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: "exclude": None } } + + Note: An "errors" key will only be present if parsing errors occurred """ result: dict[str, dict[str, Any]] = {} cache: dict[str, tuple[list[str], list[str]]] = {} @@ -1229,12 +1289,14 @@ def parse(cls, selectors: dict[str, dict[str, Any]]) -> YamlSelectors: selector_name = definition["name"] selector_definition = definition["definition"] - select, exclude = cls._parse_from_definition(selector_definition, cache=cache, rootlevel=True) + select, exclude, parse_errors = cls._parse_from_definition(selector_definition, cache=cache, rootlevel=True) 
result[selector_name] = { "select": select if select else None, "exclude": exclude if exclude else None, } + if parse_errors: + result[selector_name]["errors"] = parse_errors cache[selector_name] = (select, exclude) diff --git a/cosmos/hooks/subprocess.py b/cosmos/hooks/subprocess.py index 2e12876625..8d016c887f 100644 --- a/cosmos/hooks/subprocess.py +++ b/cosmos/hooks/subprocess.py @@ -99,6 +99,9 @@ def pre_exec() -> None: last_line: str = "" assert self.sub_process.stdout is not None + + # Make cwd available to the callback `process_log_line` function via kwargs + kwargs.setdefault("project_dir", cwd) for line in self.sub_process.stdout: line = line.rstrip("\n") last_line = line diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index a4be25bbcc..f128e29b0c 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -1,6 +1,7 @@ import json import logging from datetime import datetime, timedelta +from pathlib import Path from typing import Any from airflow.exceptions import AirflowException @@ -27,6 +28,63 @@ logger = get_logger(__name__) +def _extract_compiled_sql( + project_dir: str, unique_id: str, node_path: str | None, resource_type: str | None +) -> str | None: + """ + Extract compiled SQL from the target directory for a given dbt node. + + Used by both the subprocess strategy (via store_dbt_resource_status_from_log) + and the node-event strategy (via DbtProducerWatcherOperator._handle_node_finished); + both consume from the same target/compiled layout under project_dir. + + Assumes inputs come from dbt (relative node_path, unique_id like model.package.name). 
+ """ + if resource_type != "model" or not node_path: + return None + + package = unique_id.split(".", 2)[1] + compiled_sql_path = Path(project_dir) / "target" / "compiled" / package / "models" / node_path + if not compiled_sql_path.exists(): + logger.warning( + "Compiled sql path %s does not exist and hence the rendered template field compiled_sql for the model will not be populated", + compiled_sql_path, + ) + return None + return compiled_sql_path.read_text(encoding="utf-8").strip() or None + + +def _push_compiled_sql_for_model(task_instance: Any, unique_id: str, compiled_sql: str) -> None: + """ + Push compiled SQL for a model to XCom under the canonical key. + + Single place where both subprocess and node-event producer paths set compiled_sql. + Consumers (sensor poke and trigger) always read from this key. + """ + safe_xcom_push( + task_instance=task_instance, + key=f"{unique_id.replace('.', '__')}_compiled_sql", + value=compiled_sql, + ) + + +def store_compiled_sql_for_model( + task_instance: Any, + project_dir: str, + unique_id: str, + node_path: str | None, + resource_type: str | None, +) -> None: + """ + Read compiled SQL from the target directory and store to XCom if present. + + Single sequence used by both subprocess and node-event producer paths. + """ + compiled_sql = _extract_compiled_sql(project_dir, unique_id, node_path, resource_type) + if compiled_sql: + _push_compiled_sql_for_model(task_instance, unique_id, compiled_sql) + + def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: """ Parses a single line from dbt JSON logs and stores node status to Airflow XCom. 
@@ -54,6 +112,14 @@ def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: assert context is not None # Make MyPy happy safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) + # Extract and push compiled_sql for models (centralised for both subprocess and node-event) + # compiled_sql is available for both success and failed models - it's compiled before execution + project_dir = extra_kwargs.get("project_dir") + if project_dir: + store_compiled_sql_for_model( + context["ti"], project_dir, unique_id, node_info.get("node_path"), node_info.get("resource_type") + ) + # Additionally, log the message from dbt logs log_info = log_line.get("info", {}) msg = log_info.get("msg") @@ -91,9 +157,13 @@ def __init__( **kwargs: Any, ) -> None: self.compiled_sql = "" - extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} kwargs.setdefault("priority_weight", CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) + + # We pop extra_context before super().__init__ because BaseSensorOperator does not accept it and + # would raise on unknown kwargs; we assign after so subclasses and user customisations can use it + # at runtime (e.g. in execute, poke, or callbacks). 
+ extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} super().__init__( poke_interval=poke_interval, timeout=timeout, @@ -103,6 +173,8 @@ def __init__( profiles_dir=profiles_dir, **kwargs, ) + self.extra_context = extra_context + self.project_dir = project_dir self.producer_task_id = producer_task_id self.deferrable = deferrable @@ -175,8 +247,8 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: logger.info("Node Info: %s", run_results_json) self.compiled_sql = node_result.get("compiled_code") - if self.compiled_sql: - self._override_rtif(context) # type: ignore[attr-defined] + if self.compiled_sql and hasattr(self, "_override_rtif"): + self._override_rtif(context) return node_result.get("status") @@ -230,6 +302,13 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: self.model_unique_id, ) + # Extract and store compiled_sql from the event if available + compiled_sql = event.get("compiled_sql") + if compiled_sql: + self.compiled_sql = compiled_sql + if hasattr(self, "_override_rtif"): + self._override_rtif(context) + if status != "failed": return @@ -275,6 +354,16 @@ def poke(self, context: Context) -> bool: else: status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + # compiled_sql is always in the canonical per-model XCom key (same for event and subprocess modes) + if status is not None: + compiled_sql = get_xcom_val( + ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_compiled_sql" + ) + if compiled_sql: + self.compiled_sql = compiled_sql + if hasattr(self, "_override_rtif"): + self._override_rtif(context) + if status is None: if producer_task_state == "failed": diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index d232c08322..96ffa926fc 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -98,22 +98,32 @@ async def 
get_xcom_val(self, key: str) -> Any | None: else: return await self.get_xcom_val_af3(key) - async def _parse_node_status(self) -> str | None: - key = ( + async def _parse_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]: + """ + Parse node status and compiled_sql from XCom. + + Returns a tuple of (status, compiled_sql). + Status comes from mode-specific keys (nodefinished_* for event, *_status for subprocess). + compiled_sql is always read from the canonical per-model key (same for both modes). + """ + status_key = ( f"nodefinished_{self.model_unique_id.replace('.', '__')}" if self.use_event else f"{self.model_unique_id.replace('.', '__')}_status" ) + compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql" if self.use_event: - compressed_xcom_val = await self.get_xcom_val(key) + compressed_xcom_val = await self.get_xcom_val(status_key) if not compressed_xcom_val: - return None - + return None, None data_json = _parse_compressed_xcom(compressed_xcom_val) - return data_json.get("data", {}).get("run_result", {}).get("status") # type: ignore[no-any-return] + status = data_json.get("data", {}).get("run_result", {}).get("status") + else: + status = await self.get_xcom_val(status_key) - return await self.get_xcom_val(key) + compiled_sql = await self.get_xcom_val(compiled_sql_key) if status is not None else None + return status, compiled_sql async def _get_producer_task_status(self) -> str | None: """Retrieve the producer task state for both Airflow 2 and Airflow 3.""" @@ -135,14 +145,20 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while True: producer_task_state = await self._get_producer_task_status() - node_status = await self._parse_node_status() + node_status, compiled_sql = await self._parse_node_status_and_compiled_sql() if node_status == "success": logger.info("Model '%s' succeeded", self.model_unique_id) - yield TriggerEvent({"status": "success"}) # type: ignore[no-untyped-call] + event_data: dict[str, Any] = 
{"status": "success"} + if compiled_sql: + event_data["compiled_sql"] = compiled_sql + yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return elif node_status == "failed": logger.warning("Model '%s' failed", self.model_unique_id) - yield TriggerEvent({"status": "failed", "reason": "model_failed"}) # type: ignore[no-untyped-call] + event_data = {"status": "failed", "reason": "model_failed"} + if compiled_sql: + event_data["compiled_sql"] = compiled_sql + yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return elif producer_task_state == "failed": logger.error( diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index a999ea54e6..aed13ed76f 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -5,7 +5,6 @@ import zlib from collections.abc import Callable, Sequence from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException @@ -26,7 +25,11 @@ InvocationMode, ) from cosmos.log import get_logger -from cosmos.operators._watcher.base import BaseConsumerSensor, store_dbt_resource_status_from_log +from cosmos.operators._watcher.base import ( + BaseConsumerSensor, + store_compiled_sql_for_model, + store_dbt_resource_status_from_log, +) from cosmos.operators.base import ( DbtBuildMixin, DbtRunMixin, @@ -114,21 +117,6 @@ def _handle_startup_event(self, event_message: EventMsg, startup_events: list[di ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr] startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) - def _extract_compiled_sql_for_node_event(self, event_message: EventMsg) -> str | None: - if getattr(event_message.data.node_info, "resource_type", None) != "model": - return None - uid = event_message.data.node_info.unique_id - node_path = str(event_message.data.node_info.node_path) - package = uid.split(".")[1] - compiled_sql_path = Path.cwd() / 
"target" / "compiled" / package / "models" / node_path - if not compiled_sql_path.exists(): - logger.warning( - "Compiled sql path %s does not exist and hence the rendered template field compiled_sql for the model will not be populated", - compiled_sql_path, - ) - return None - return compiled_sql_path.read_text(encoding="utf-8").strip() or None - def _handle_node_finished( self, event_message: EventMsg, @@ -136,10 +124,11 @@ def _handle_node_finished( ) -> None: logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) uid = event_message.data.node_info.unique_id + node_path_val = getattr(event_message.data.node_info, "node_path", None) + node_path = str(node_path_val) if node_path_val is not None else None + resource_type = getattr(event_message.data.node_info, "resource_type", None) event_message_dict = self._serialize_event(event_message) - compiled_sql = self._extract_compiled_sql_for_node_event(event_message) - if compiled_sql: - event_message_dict["compiled_sql"] = compiled_sql + store_compiled_sql_for_model(context["ti"], self.project_dir, uid, node_path, resource_type) payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) @@ -247,6 +236,8 @@ def __init__( profile_config=profile_config, project_dir=project_dir, profiles_dir=profiles_dir, + producer_task_id=producer_task_id, + deferrable=deferrable, **kwargs, ) @@ -266,10 +257,6 @@ def _get_status_from_events(self, ti: Any, context: Context) -> Any: logger.info("Node Info: %s", event_json) - self.compiled_sql = event_json.get("compiled_sql", "") - if self.compiled_sql: - self._override_rtif(context) - return event_json.get("data", {}).get("run_result", {}).get("status") def use_event(self) -> bool: diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 8a31486d02..97db403d58 100644 --- 
a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -7,15 +7,19 @@ from cosmos.profiles.base import BaseProfileMapping DEFAULT_AWS_REGION = "us-west-2" +DEFAULT_THREADS = 4 class SnowflakeBaseProfileMapping(BaseProfileMapping): + profile_defaults: dict[str, Any] = {"threads": DEFAULT_THREADS} + @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" profile_vars = { **self.mapped_params, + **self.profile_defaults, **self.profile_args, } return profile_vars diff --git a/docs/configuration/caching.rst b/docs/configuration/caching.rst index db2ce87039..7289d00742 100644 --- a/docs/configuration/caching.rst +++ b/docs/configuration/caching.rst @@ -172,7 +172,7 @@ As an example, the following clean-up DAG will delete any cache associated with **Cache key** -The Airflow variables that represent the dbt ls cache are prefixed by ``cosmos_cache``. +The Airflow variables that represent the yaml selectors cache are prefixed by ``cosmos_cache``. When using ``DbtDag``, the keys use the DAG name. When using ``DbtTaskGroup``, they contain the ``TaskGroup`` and parent task groups and DAG. 
Examples: diff --git a/docs/configuration/scheduling.rst b/docs/configuration/scheduling.rst index 622803176c..3c9e607266 100644 --- a/docs/configuration/scheduling.rst +++ b/docs/configuration/scheduling.rst @@ -92,7 +92,6 @@ Will trigger the following DAG to be run (when using Cosmos 1.1 when using Airfl from airflow.datasets import Dataset from airflow.operators.empty import EmptyOperator - with DAG( "dataset_triggered_dag", description="A DAG that should be triggered via Dataset", @@ -118,7 +117,6 @@ From Cosmos 1.7 and Airflow 2.10, it is also possible to trigger DAGs be to be r from airflow.datasets import DatasetAlias from airflow.operators.empty import EmptyOperator - with DAG( "datasetalias_triggered_dag", description="A DAG that should be triggered via Dataset alias", diff --git a/docs/configuration/selecting-excluding.rst b/docs/configuration/selecting-excluding.rst index 658f59c48f..932b924ad2 100644 --- a/docs/configuration/selecting-excluding.rst +++ b/docs/configuration/selecting-excluding.rst @@ -180,3 +180,26 @@ Parsing of the ``default`` and ``indirect_selection`` keywords is not currently In the event the dbt YAML selector specification changes, Cosmos will attempt to keep up to date with the changes, but there may be a lag between dbt releases and Cosmos releases. Once a new Cosmos version is released with the updated selector parsing logic, users should update their Cosmos version to ensure compatibility with the latest dbt selector specification. For subsequent updates to the YAML selector parser, existing YAML selector caches will be invalidated the next time the DAG is parsed. 
+ +**Error Handling** + +Cosmos distinguishes between two types of errors when parsing YAML selectors: + +- **Structural YAML Errors** - These cause immediate failure during manifest parsing: + + - Selector definition is not a dictionary + - Missing required ``name`` key + - Missing required ``definition`` key + + These errors indicate malformed YAML structure and will raise a ``CosmosValueError`` immediately when calling ``YamlSelectors.parse()``. + +- **Selector Definition Errors** - These are isolated and surfaced when accessing the selector: + + - Unsupported selector methods (e.g., ``method: "state"``, ``method: "package"``) + - Invalid graph operator configurations (e.g., non-integer depth values) + - Invalid selector logic (e.g., multiple root keys in a definition) + + These errors are collected during parsing but only raised when you attempt to retrieve the selector using ``get_parsed(selector_name)``. + This allows the manifest to be loaded successfully even if some selectors have definition errors, enabling you to work with valid selectors while debugging invalid ones. + +If a selector has multiple definition errors, they will all be reported together in a formatted error message when accessing the selector. 
diff --git a/docs/configuration/testing-behavior.rst b/docs/configuration/testing-behavior.rst index fdd4fd47fb..4e18765be9 100644 --- a/docs/configuration/testing-behavior.rst +++ b/docs/configuration/testing-behavior.rst @@ -172,7 +172,6 @@ The test will only run once after both models run, leading the DAG to succeed: from cosmos import DbtDag, RenderConfig - example_multiple_parents_test = DbtDag( ..., render_config=RenderConfig( diff --git a/docs/generate_mappings.py b/docs/generate_mappings.py index f04f776478..52a7b1a787 100644 --- a/docs/generate_mappings.py +++ b/docs/generate_mappings.py @@ -72,6 +72,7 @@ def generate_mapping_docs( "mapping_description": "\n\n".join(docstring.split("\n")), "fields": [field.__dict__ for field in get_fields_from_mapping(mapping=mapping)], "airflow_conn_type": mapping.airflow_connection_type, + "profile_defaults": getattr(mapping, "profile_defaults", {}), } ) ) diff --git a/docs/getting_started/execution-modes-local-conflicts.rst b/docs/getting_started/execution-modes-local-conflicts.rst index 8dbcf01130..9fec173751 100644 --- a/docs/getting_started/execution-modes-local-conflicts.rst +++ b/docs/getting_started/execution-modes-local-conflicts.rst @@ -109,7 +109,6 @@ The table was created by running `nox `__ wi import nox - nox.options.sessions = ["compatibility"] nox.options.reuse_existing_virtualenvs = True diff --git a/docs/getting_started/kubernetes.rst b/docs/getting_started/kubernetes.rst index 34bbefb3a4..2dd2385e4c 100644 --- a/docs/getting_started/kubernetes.rst +++ b/docs/getting_started/kubernetes.rst @@ -149,3 +149,19 @@ Enable and trigger a run of the `jaffle_shop_k8s `_ to address this) +- Does not emit Airflow datasets, assets, and dataset aliases (there is an `open ticket `_ to address this) +- Does not handle installing dbt deps for users (there is an `open ticket `_ to address this) +- Does not support `ProfileMapping `_ (there is an `open ticket `_ to address this) +- Does not support `Callbacks `_ (there is 
an `open ticket `_ to address this) +- Does not expose Compiled SQL as a `templated field `_ +- Does not benefit from `Cosmos caching mechanisms `_ +- Does not support `generating dbt docs & uploading to an object store `_ (there is a `PR `_ to solve this for S3) diff --git a/docs/getting_started/watcher-kubernetes-execution-mode.rst b/docs/getting_started/watcher-kubernetes-execution-mode.rst index 5e882ba620..16dbbffd0a 100644 --- a/docs/getting_started/watcher-kubernetes-execution-mode.rst +++ b/docs/getting_started/watcher-kubernetes-execution-mode.rst @@ -174,6 +174,8 @@ The following limitations from ``ExecutionMode.WATCHER`` also apply to ``Executi For more details on these limitations, refer to the :ref:`watcher-execution-mode` documentation. +Additionally, the limitations from ``ExecutionMode.KUBERNETES`` also apply to ``ExecutionMode.WATCHER_KUBERNETES``. For details, refer to the :ref:`kubernetes-known-limitations` documentation. + ------------------------------------------------------------------------------- Example DAG diff --git a/docs/templates/profile_mapping.rst.jinja2 b/docs/templates/profile_mapping.rst.jinja2 index ff9ad8da17..c5b25b48b1 100644 --- a/docs/templates/profile_mapping.rst.jinja2 +++ b/docs/templates/profile_mapping.rst.jinja2 @@ -54,3 +54,22 @@ Some notes about the table above: For example, if the Airflow field name is ``['password', 'extra.token']``, the profile mapping will first look for a field named ``password``. If that field is not present, it will look for ``extra.token``. +{% if profile_defaults %} + +Default Values +-------------- + +This profile mapping sets the following default values. These can be overridden by passing +them in ``profile_args``. + +.. 
list-table:: + :header-rows: 1 + + * - Field Name + - Default Value + + {% for key, value in profile_defaults.items() %} + * - ``{{ key }}`` + - ``{{ value }}`` + {% endfor %} +{% endif %} diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index c8c810871d..c87ea56a78 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -2208,7 +2208,7 @@ def test_save_yaml_selectors_cache(mock_variable_set, mock_datetime, tmp_dbt_pro hash_dir, hash_selectors, hash_impl = version.split(",") assert hash_selectors == "43303af03e84e3b51fbfcf598261fae4" - assert hash_impl == "3ae7ccd90b387308920fa408907de75d" + assert hash_impl == "c7ba8b331e80ae876d5c7c7c1cfdf93d" if sys.platform == "darwin": # We faced inconsistent hashing versions depending on the version of MacOS/Linux - the following line aims to address these. @@ -2377,7 +2377,59 @@ def test_should_use_yaml_selectors_cache(enable_cache, enable_cache_yaml_selecto assert graph.should_use_yaml_selectors_cache() == should_use -@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=True) +@patch("cosmos.dbt.graph.DbtGraph.should_use_yaml_selectors_cache", return_value=True) +@patch("cosmos.dbt.graph.Variable.get") +def test_cache_miss_when_loading_dbt_ls_cache_as_yaml_selectors_cache( + mock_variable_get, mock_should_use_yaml_selectors_cache, mock_should_use_dbt_ls_cache, tmp_dbt_project_dir +): + """ + Test that loading a dbt ls cache as a yaml selectors cache causes a cache miss. + + This ensures that when both cache types use the same Airflow Variable key, attempting to load + a dbt ls cache as a yaml selectors cache will fail gracefully and return a cache miss instead of corrupted data. 
+ """ + graph = DbtGraph(cache_identifier="test_swap", project=ProjectConfig(dbt_project_path=tmp_dbt_project_dir)) + + dbt_ls_cache_data = { + "version": "hash_dir,hash_args", # dbt ls version format (2 parts) + "dbt_ls_compressed": "eJwrzs9NVcgvLSkoLQEAGpAEhg==", + "last_modified": "2022-01-01T12:00:00", + } + mock_variable_get.return_value = dbt_ls_cache_data + + yaml_cache_result = graph.get_yaml_selectors_cache() + + assert yaml_cache_result == {}, "Expected cache miss when loading dbt ls cache as yaml selectors cache" + + +@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=True) +@patch("cosmos.dbt.graph.DbtGraph.should_use_yaml_selectors_cache", return_value=True) +@patch("cosmos.dbt.graph.Variable.get") +def test_cache_miss_when_loading_yaml_selectors_cache_as_dbt_ls_cache( + mock_variable_get, mock_should_use_yaml_selectors_cache, mock_should_use_dbt_ls_cache, tmp_dbt_project_dir +): + """ + Test that loading a yaml selectors cache as a dbt ls cache causes a cache miss. + + This ensures that when both cache types use the same Airflow Variable key, attempting to load + a yaml selectors cache as a dbt ls cache will fail gracefully and return a cache miss instead of corrupted data. 
+ """ + graph = DbtGraph(cache_identifier="test_swap", project=ProjectConfig(dbt_project_path=tmp_dbt_project_dir)) + + yaml_selectors_cache_data = { + "version": "hash_dir,hash_selectors,hash_impl", # yaml selectors version format (3 parts) + "raw_selectors_compressed": "eJyrViouSUzPzEuPzy9KSS0qVrJSqFZKSU3LzMssyczPA3NzU0sy8lOATCWgUiUdBaWyxJzSVCg/PlGpFiiUl5gLFkEzrbYWAFRnILk=", + "parsed_selectors_compressed": "eJyrVkqtSM4pTUlVslLIK83J0VFQKk7NSU0uAfKjlUoS062AOD5RKbYWADB2DhQ=", + "last_modified": "2022-01-01T12:00:00", + } + mock_variable_get.return_value = yaml_selectors_cache_data + + dbt_ls_cache_result = graph.get_dbt_ls_cache() + + assert dbt_ls_cache_result == {}, "Expected cache miss when loading yaml selectors cache as dbt ls cache" + + @patch(object_storage_path) @patch("cosmos.config.ProjectConfig") @patch("cosmos.dbt.graph._configure_remote_cache_dir") @@ -2635,9 +2687,9 @@ def test__normalize_path(): "pre_dbt_fusion_value,source_rendering_behaviour_value,expected_args_count", [ (True, SourceRenderingBehavior.NONE, 4), - (False, SourceRenderingBehavior.NONE, 13), - (True, SourceRenderingBehavior.ALL, 13), - (False, SourceRenderingBehavior.ALL, 13), + (False, SourceRenderingBehavior.NONE, 14), + (True, SourceRenderingBehavior.ALL, 14), + (False, SourceRenderingBehavior.ALL, 14), ], ) @patch("cosmos.dbt.graph.settings") diff --git a/tests/dbt/test_runner.py b/tests/dbt/test_runner.py index ea17592ed7..a696b9954d 100644 --- a/tests/dbt/test_runner.py +++ b/tests/dbt/test_runner.py @@ -2,8 +2,9 @@ import shutil import sys import tempfile +import types from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch from airflow import DAG from pendulum import datetime @@ -58,6 +59,99 @@ def test_is_available_is_false(): assert not dbt_runner.is_available() +def test_cleanup_dbt_adapters_calls_reset_adapters_and_gc(): + """_cleanup_dbt_adapters calls reset_adapters when available and always runs gc.collect().""" + 
factory_mock = MagicMock() + with ( + patch("cosmos.dbt.runner.gc") as mock_gc, + patch.dict("sys.modules", {"dbt.adapters.factory": factory_mock}), + ): + dbt_runner._cleanup_dbt_adapters() + factory_mock.reset_adapters.assert_called_once() + mock_gc.collect.assert_called_once() + + +def test_cleanup_dbt_adapters_handles_import_error(): + """_cleanup_dbt_adapters does not raise when dbt.adapters.factory is not available.""" + + # Use a fake module in sys.modules so the "from dbt.adapters.factory import ..." fails + # with ImportError without patching builtins.__import__ globally. + class FakeFactoryModule(types.ModuleType): + def __getattr__(self, name: str): + raise ImportError("No module named 'dbt.adapters.factory'") + + fake_factory = FakeFactoryModule("dbt.adapters.factory") + with ( + patch("cosmos.dbt.runner.gc") as mock_gc, + patch.dict("sys.modules", {"dbt.adapters.factory": fake_factory}), + ): + dbt_runner._cleanup_dbt_adapters() + mock_gc.collect.assert_called_once() + + +def test_cleanup_dbt_adapters_handles_reset_exception(): + """_cleanup_dbt_adapters catches exceptions from reset_adapters and still runs gc.collect().""" + factory_mock = MagicMock() + factory_mock.reset_adapters.side_effect = RuntimeError("adapter error") + with ( + patch("cosmos.dbt.runner.gc") as mock_gc, + patch("cosmos.dbt.runner.logger") as mock_logger, + patch.dict("sys.modules", {"dbt.adapters.factory": factory_mock}), + ): + dbt_runner._cleanup_dbt_adapters() + mock_gc.collect.assert_called_once() + mock_logger.debug.assert_called_once_with("Error resetting dbt adapters", exc_info=True) + + +def test_run_command_calls_cleanup_dbt_adapters(): + """run_command calls _cleanup_dbt_adapters after runner.invoke to release semaphores.""" + fake_result = MagicMock() + fake_result.success = True + fake_result.exception = None + fake_result.result = None + + fake_runner = MagicMock() + fake_runner.invoke.return_value = fake_result + + with ( + patch.object(dbt_runner, "get_runner", 
return_value=fake_runner), + patch.object(dbt_runner, "_cleanup_dbt_adapters") as mock_cleanup, + patch.object(dbt_runner, "change_working_directory"), + patch.object(dbt_runner, "environ"), + patch.object(dbt_runner, "logger"), + ): + result = dbt_runner.run_command( + command=["dbt", "deps"], + env={}, + cwd="/tmp/project", + ) + assert result is fake_result + fake_runner.invoke.assert_called_once() + mock_cleanup.assert_called_once() + + +def test_run_command_calls_cleanup_dbt_adapters_when_invoke_raises(): + """run_command calls _cleanup_dbt_adapters even when runner.invoke raises (try/finally).""" + fake_runner = MagicMock() + fake_runner.invoke.side_effect = RuntimeError("invoke failed") + + with ( + patch.object(dbt_runner, "get_runner", return_value=fake_runner), + patch.object(dbt_runner, "_cleanup_dbt_adapters") as mock_cleanup, + patch.object(dbt_runner, "change_working_directory"), + patch.object(dbt_runner, "environ"), + patch.object(dbt_runner, "logger"), + ): + with pytest.raises(RuntimeError, match="invoke failed"): + dbt_runner.run_command( + command=["dbt", "deps"], + env={}, + cwd="/tmp/project", + ) + fake_runner.invoke.assert_called_once() + mock_cleanup.assert_called_once() + + @pytest.mark.integration def test_is_available_is_true(): assert dbt_runner.is_available() diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index 1dc951f396..22ea6f96c6 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -507,6 +507,44 @@ def test_select_node_by_child_and_precursors_no_node(): assert list(selected.keys()) == expected +def test_select_nodes_by_precursors_with_external_dependency(): + """Test that the + selector handles depends_on references to nodes not in the nodes dict. + + When using dbt-loom for cross-project references, external nodes are filtered out during + manifest loading (they have no file path). However, local nodes may still have depends_on + entries pointing to these external nodes. 
The + selector should gracefully skip missing + nodes instead of raising a KeyError. + """ + external_upstream_id = "model.upstream_project.external_model" + + local_staging = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.local_staging", + resource_type=DbtResourceType.MODEL, + depends_on=[external_upstream_id], + file_path=SAMPLE_PROJ_PATH / "models/local_staging.sql", + tags=[], + config={}, + ) + local_marts = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.local_marts", + resource_type=DbtResourceType.MODEL, + depends_on=[local_staging.unique_id], + file_path=SAMPLE_PROJ_PATH / "models/local_marts.sql", + tags=[], + config={}, + ) + + nodes_with_external_dep = { + local_staging.unique_id: local_staging, + local_marts.unique_id: local_marts, + } + + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=nodes_with_external_dep, select=["+local_marts"]) + assert local_marts.unique_id in selected + assert local_staging.unique_id in selected + assert external_upstream_id not in selected + + def test_select_node_by_descendants(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["grandparent+"]) expected = [ @@ -1421,11 +1459,12 @@ def test_valid_graph_operator_yaml_selectors(selector_name, selector_definition, ) def test_invalid_cosmos_method_yaml_selectors(selector_name, selector_definition, exception_msg): selectors = {selector_name: selector_definition} + yaml_selectors = YamlSelectors.parse(selectors) with pytest.raises(CosmosValueError) as err_info: - _ = YamlSelectors.parse(selectors) + _ = yaml_selectors.get_parsed(selector_name) - assert err_info.value.args[0] == exception_msg + assert exception_msg in str(err_info.value) @pytest.mark.parametrize( @@ -1541,6 +1580,124 @@ def test_invalid_cosmos_method_yaml_selectors(selector_name, selector_definition {"name": "unknown_random_method", "definition": {"method": "unknown_random", "value": "test"}}, "Unsupported 
selector method: 'unknown_random'", ), + ( + "multiple_errors_non_int_depths", + { + "name": "multiple_errors_non_int_depths", + "definition": { + "method": "tag", + "value": "nightly", + "parents": True, + "parents_depth": "invalid", + "children": True, + "children_depth": "also_invalid", + }, + }, + ["parents_depth must be an integer", "children_depth must be an integer"], + ), + ( + "union_with_multiple_errors", + { + "name": "union_with_multiple_errors", + "definition": { + "union": [ + {"method": "unsupported_method1", "value": "test1"}, + {"method": "unsupported_method2", "value": "test2"}, + {"method": "tag", "value": "valid"}, + ] + }, + }, + [ + "Unsupported selector method: 'unsupported_method1'", + "Unsupported selector method: 'unsupported_method2'", + ], + ), + ( + "intersection_with_multiple_errors", + { + "name": "intersection_with_multiple_errors", + "definition": { + "intersection": [ + {"method": "tag", "value": "nightly", "parents": True, "parents_depth": "invalid1"}, + {"method": "path", "value": "models/", "children": True, "children_depth": "invalid2"}, + ] + }, + }, + ["parents_depth must be an integer", "children_depth must be an integer"], + ), + ( + "nested_errors", + { + "name": "nested_errors", + "definition": { + "union": [ + { + "method": "tag", + "value": "nightly", + "parents": True, + "parents_depth": "not_an_int", + "childrens_parents": True, + }, + {"method": "unknown_method", "value": "test"}, + ] + }, + }, + ["parents_depth must be an integer", "Unsupported selector method: 'unknown_method'"], + ), + ( + "errors_in_include_and_exclude", + { + "name": "errors_in_include_and_exclude", + "definition": { + "method": "unknown_include_method", + "value": "test", + "exclude": [{"method": "unknown_exclude_method", "value": "test2"}], + }, + }, + [ + "Unsupported selector method: 'unknown_include_method'", + "Unsupported selector method: 'unknown_exclude_method'", + ], + ), + ( + "multiple_errors_complex", + { + "name": 
"multiple_errors_complex", + "definition": { + "union": [ + {"method": "tag", "value": "nightly", "exclude": [{"method": "tag", "value": "deprecated"}]}, + {"method": "path", "value": "models/", "exclude": [{"method": "tag", "value": "archived"}]}, + {"method": "invalid_method", "value": "test"}, + ] + }, + }, + ["You cannot provide multiple exclude arguments", "Unsupported selector method: 'invalid_method'"], + ), + ( + "list_and_method_errors", + { + "name": "list_and_method_errors", + "definition": {"union": [{"method": "unknown_method", "value": "test"}, "not_a_dict"]}, + }, + ["Unsupported selector method: 'unknown_method'", "Invalid value type", "not_a_dict"], + ), + ], +) +def test_invalid_dbt_method_yaml_selectors(selector_name, selector_definition, exception_msg): + selectors = {selector_name: selector_definition} + yaml_selectors = YamlSelectors.parse(selectors) + + with pytest.raises(CosmosValueError) as err_info: + _ = yaml_selectors.get_parsed(selector_name) + + error_message = str(err_info.value) + expected_messages = exception_msg if isinstance(exception_msg, list) else [exception_msg] + assert all(msg in error_message for msg in expected_messages) + + +@pytest.mark.parametrize( + "selector_name, selector_definition, exception_msg", + [ ( "non_dict_definition", "not_a_dict", @@ -1558,13 +1715,14 @@ def test_invalid_cosmos_method_yaml_selectors(selector_name, selector_definition ), ], ) -def test_invalid_dbt_method_yaml_selectors(selector_name, selector_definition, exception_msg): +def test_invalid_yaml_structure_raises_immediately(selector_name, selector_definition, exception_msg): + """Test that structural YAML errors (missing name/definition) raise immediately from parse().""" selectors = {selector_name: selector_definition} with pytest.raises(CosmosValueError) as err_info: _ = YamlSelectors.parse(selectors) - assert err_info.value.args[0] == exception_msg + assert exception_msg in str(err_info.value) @pytest.mark.parametrize( @@ -1577,11 +1735,12 
@@ def test_invalid_dbt_method_yaml_selectors(selector_name, selector_definition, e ) def test_invalid_value_type_for_list_keys(selector_name, definition_key, invalid_value, expected_error_fragment): selectors = {selector_name: {"name": selector_name, "definition": {definition_key: invalid_value}}} + yaml_selectors = YamlSelectors.parse(selectors) with pytest.raises(CosmosValueError) as err_info: - _ = YamlSelectors.parse(selectors) + _ = yaml_selectors.get_parsed(selector_name) - assert expected_error_fragment in err_info.value.args[0] + assert expected_error_fragment in str(err_info.value) @pytest.mark.parametrize( @@ -1603,13 +1762,15 @@ def test_invalid_items_in_selector_lists(selector_name, definition_key, invalid_ } } + yaml_selectors = YamlSelectors.parse(selectors) + with pytest.raises(CosmosValueError) as err_info: - _ = YamlSelectors.parse(selectors) + _ = yaml_selectors.get_parsed(selector_name) - assert ( + expected_error = ( f'Invalid value type {item_type_str} in key "{definition_key}", expected dict (value: {invalid_list_item}).' 
- in err_info.value.args[0] ) + assert expected_error in str(err_info.value) @pytest.mark.parametrize( @@ -1630,11 +1791,19 @@ def test_non_string_keys_in_selector_dicts(selector_name, definition_key, non_st } } - with pytest.raises(CosmosValueError) as err_info: - _ = YamlSelectors.parse(selectors) + yaml_selectors = YamlSelectors.parse(selectors) - assert f'Expected all keys to "{definition_key}" dict to be strings' in err_info.value.args[0] - assert f'but "{non_string_key}" is a "{key_type_str}"' in err_info.value.args[0] + with pytest.raises(CosmosValueError) as err_info: + _ = yaml_selectors.get_parsed(selector_name) + + error_message = str(err_info.value) + assert all( + part in error_message + for part in [ + f'Expected all keys to "{definition_key}" dict to be strings', + f'but "{non_string_key}" is a "{key_type_str}"', + ] + ) def test_selector_reference_resolves_from_cache(): diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index d620509a1d..07623e0300 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -81,21 +81,40 @@ async def runner(*args, **kwargs): assert none_result is None @pytest.mark.parametrize( - "use_event, xcom_val, expected_status", + "use_event, xcom_val, expected_status, expected_compiled_sql", [ - (True, {"data": {"run_result": {"status": "success"}}}, "success"), - (True, None, None), - (False, "failed", "failed"), + # Event mode: status from event payload; compiled_sql from canonical *_compiled_sql key only + (True, {"data": {"run_result": {"status": "success"}}}, "success", "SELECT 1"), + (True, {"data": {"run_result": {"status": "success"}}}, "success", None), + (True, None, None, None), + # Subprocess mode: status from *_status key; compiled_sql from canonical key + (False, "failed", "failed", None), + (False, "success", "success", "SELECT * FROM table"), ], ) - async def test_parse_node_status(self, use_event, xcom_val, 
expected_status): + async def test_parse_node_status_and_compiled_sql( + self, use_event, xcom_val, expected_status, expected_compiled_sql + ): self.trigger.use_event = use_event + + async def mock_get_xcom_val(key): + # compiled_sql is always read from the canonical key (same for both modes) + if key.endswith("_compiled_sql"): + return expected_compiled_sql + if use_event: + return xcom_val if xcom_val else None + # Subprocess mode: status from per-model key + if key.endswith("_status"): + return xcom_val + return None + with ( patch("cosmos.operators._watcher.triggerer._parse_compressed_xcom", return_value=xcom_val), - patch.object(self.trigger, "get_xcom_val", AsyncMock(return_value=xcom_val)), + patch.object(self.trigger, "get_xcom_val", AsyncMock(side_effect=mock_get_xcom_val)), ): - status = await self.trigger._parse_node_status() + status, compiled_sql = await self.trigger._parse_node_status_and_compiled_sql() assert status == expected_status + assert compiled_sql == expected_compiled_sql @pytest.mark.parametrize( "airflow_version, expected_val", @@ -127,6 +146,9 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): async def test_run_various_outcomes(self, node_status, producer_state, expected): async def fake_get_xcom_val(key): + # Return None for compiled_sql key so payload matches expected (no compiled_sql) + if key.endswith("_compiled_sql"): + return None return "compressed_data" with ( @@ -210,13 +232,13 @@ def _import_side_effect(name: str, *args, **kwargs): async def test_run_producer_success_model_not_run(self, caplog): """Test that when producer succeeds but model has no status, trigger yields success with model_not_run reason.""" get_producer_status_mock = AsyncMock(return_value="success") - parse_node_status_mock = AsyncMock(return_value=None) + parse_node_status_and_compiled_sql_mock = AsyncMock(return_value=(None, None)) caplog.set_level("INFO") with ( patch.object(self.trigger, "_get_producer_task_status", 
get_producer_status_mock), - patch.object(self.trigger, "_parse_node_status", parse_node_status_mock), + patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_node_status_and_compiled_sql_mock), ): events = [] async for event in self.trigger.run(): @@ -231,14 +253,16 @@ async def test_run_producer_success_model_not_run(self, caplog): async def test_run_poke_interval_and_debug_log(self, caplog): get_xcom_val_mock = AsyncMock(side_effect=["compressed_data"]) get_producer_status_mock = AsyncMock(side_effect=["running", "running", "running"]) - parse_node_status_mock = AsyncMock(side_effect=[None, None, "success"]) + parse_node_status_and_compiled_sql_mock = AsyncMock( + side_effect=[(None, None), (None, None), ("success", "SELECT 1")] + ) caplog.set_level("DEBUG") with ( patch.object(self.trigger, "get_xcom_val", get_xcom_val_mock), patch.object(self.trigger, "_get_producer_task_status", get_producer_status_mock), - patch("cosmos.operators._watcher.triggerer.WatcherTrigger._parse_node_status", parse_node_status_mock), + patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_node_status_and_compiled_sql_mock), patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock, ): events = [] @@ -248,3 +272,18 @@ async def test_run_poke_interval_and_debug_log(self, caplog): sleep_mock.assert_awaited() assert events[0].payload["status"] == "success" + assert events[0].payload["compiled_sql"] == "SELECT 1" + + @pytest.mark.asyncio + async def test_run_failed_model_includes_compiled_sql_in_event(self): + """When model fails and compiled_sql is available, event payload includes it.""" + parse_mock = AsyncMock(return_value=("failed", "SELECT * FROM broken_model")) + with ( + patch.object(self.trigger, "_get_producer_task_status", AsyncMock(return_value="running")), + patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_mock), + ): + events = [event async for event in self.trigger.run()] + assert len(events) == 1 + assert 
events[0].payload["status"] == "failed" + assert events[0].payload["reason"] == "model_failed" + assert events[0].payload["compiled_sql"] == "SELECT * FROM broken_model" diff --git a/tests/operators/_watcher/test_watcher_base.py b/tests/operators/_watcher/test_watcher_base.py index 775eeb66f1..350081b9b3 100644 --- a/tests/operators/_watcher/test_watcher_base.py +++ b/tests/operators/_watcher/test_watcher_base.py @@ -22,3 +22,34 @@ class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): with pytest.raises(NotImplementedError): assert sensor._get_status_from_events(None, None) is None + + def test_extra_context_is_stored_on_instance(self): + """Consumer sensor stores extra_context so it is available at runtime.""" + + class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): + something_to_be_implemented = True + + extra_context = {"dbt_node_config": {"unique_id": "model.jaffle_shop.stg_orders"}, "run_id": "run_123"} + sensor = SubclassBaseConsumerSensor( + task_id="test_sensor", + producer_task_id="dbt_run_local", + profile_config=None, + project_dir="/tmp/sample_project", + extra_context=extra_context, + ) + assert sensor.extra_context == extra_context + assert sensor.model_unique_id == "model.jaffle_shop.stg_orders" + + def test_extra_context_defaults_to_empty_dict_when_not_passed(self): + """When extra_context is not in kwargs, sensor.extra_context is {}.""" + + class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): + something_to_be_implemented = True + + sensor = SubclassBaseConsumerSensor( + task_id="test_sensor", + producer_task_id="dbt_run_local", + profile_config=None, + project_dir="/tmp/sample_project", + ) + assert sensor.extra_context == {} diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 7120216a1a..81af69683a 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -18,6 +18,7 @@ from cosmos import DbtDag, ExecutionConfig, 
ProfileConfig, ProjectConfig, RenderConfig, TestBehavior from cosmos.config import InvocationMode from cosmos.constants import PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode +from cosmos.operators._watcher.base import store_compiled_sql_for_model from cosmos.operators._watcher.triggerer import WatcherTrigger from cosmos.operators.watcher import ( DbtBuildWatcherOperator, @@ -260,11 +261,11 @@ def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): ({"status": "success", "reason": "model_not_run"}, None), ( {"status": "failed", "reason": "model_failed"}, - "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher' logs for details.", + "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher_operator' logs for details.", ), ( {"status": "failed", "reason": "producer_failed"}, - "Watcher producer task 'dbt_producer_watcher' failed before reporting model results. Check its logs for the underlying error.", + "Watcher producer task 'dbt_producer_watcher_operator' failed before reporting model results. 
Check its logs for the underlying error.", ), ], ) @@ -325,10 +326,8 @@ def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch): ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") op._handle_node_finished(ev, ctx) - stored = list(ti.store.values())[0] - raw = zlib.decompress(base64.b64decode(stored)).decode() - data = json.loads(raw) - assert data.get("compiled_sql") == sql_text + # Compiled SQL is pushed to the canonical XCom key (single strategy for both event and subprocess) + assert ti.store.get("model__pkg__my_model_compiled_sql") == sql_text def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, monkeypatch): @@ -343,10 +342,11 @@ def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, mon ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") op._handle_node_finished(ev, ctx) + # Event payload does not contain compiled_sql; canonical key is only set when extraction succeeds stored = list(ti.store.values())[0] - raw = zlib.decompress(base64.b64decode(stored)).decode() - data = json.loads(raw) + data = json.loads(zlib.decompress(base64.b64decode(stored)).decode()) assert "compiled_sql" not in data + assert "model__pkg__my_model_compiled_sql" not in ti.store def test_execute_streaming_mode(): @@ -620,6 +620,136 @@ def test_process_log_line_callable_integration_with_subprocess_pattern(self): assert ti.store.get("model__pkg__model_b_status") == "failed" assert len(ti.store) == 2 # Only success and failed statuses are stored + def test_store_dbt_resource_status_from_log_pushes_compiled_sql_for_models(self, tmp_path): + """Test that compiled_sql is pushed to XCom for successful model nodes.""" + ti = _MockTI() + ctx = {"ti": ti} + + # Create a fake compiled SQL file + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + compiled_sql_file 
= compiled_dir / "my_model.sql" + compiled_sql_file.write_text("SELECT * FROM orders WHERE status = 'completed'") + + # JSON log with node_path and resource_type + log_line = json.dumps( + { + "data": { + "node_info": { + "node_status": "success", + "unique_id": "model.pkg.my_model", + "node_path": "my_model.sql", + "resource_type": "model", + } + } + } + ) + + store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) + + assert ti.store.get("model__pkg__my_model_status") == "success" + assert ti.store.get("model__pkg__my_model_compiled_sql") == "SELECT * FROM orders WHERE status = 'completed'" + + def test_store_dbt_resource_status_from_log_no_compiled_sql_for_non_models(self, tmp_path): + """Test that compiled_sql is not pushed for non-model resources like seeds or tests.""" + ti = _MockTI() + ctx = {"ti": ti} + + # JSON log for a seed (not a model) + log_line = json.dumps( + { + "data": { + "node_info": { + "node_status": "success", + "unique_id": "seed.pkg.my_seed", + "node_path": "my_seed.csv", + "resource_type": "seed", + } + } + } + ) + + store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) + + assert ti.store.get("seed__pkg__my_seed_status") == "success" + assert "seed__pkg__my_seed_compiled_sql" not in ti.store + + def test_store_dbt_resource_status_from_log_pushes_compiled_sql_on_failure(self, tmp_path): + """Test that compiled_sql is pushed for failed models too (useful for debugging).""" + ti = _MockTI() + ctx = {"ti": ti} + + # Create a fake compiled SQL file (compilation happens before execution, so it exists even if model fails) + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + compiled_sql_file = compiled_dir / "failed_model.sql" + compiled_sql_file.write_text("SELECT * FROM orders WHERE invalid_column = 1") + + log_line = json.dumps( + { + "data": { + "node_info": { + "node_status": "failed", + "unique_id": 
"model.pkg.failed_model", + "node_path": "failed_model.sql", + "resource_type": "model", + } + } + } + ) + + store_dbt_resource_status_from_log(log_line, {"context": ctx, "project_dir": str(tmp_path)}) + + assert ti.store.get("model__pkg__failed_model_status") == "failed" + assert ti.store.get("model__pkg__failed_model_compiled_sql") == "SELECT * FROM orders WHERE invalid_column = 1" + + +class TestStoreCompiledSqlForModelPathHandling: + """Tests for store_compiled_sql_for_model (node_path is relative to target/compiled//models/).""" + + def test_missing_node_path_does_not_push(self, tmp_path): + """When node_path is None or empty, no compiled_sql is extracted or pushed.""" + ti = _MockTI() + store_compiled_sql_for_model(ti, str(tmp_path), "model.pkg.m", None, "model") + assert "model__pkg__m_compiled_sql" not in ti.store + + ti2 = _MockTI() + store_compiled_sql_for_model(ti2, str(tmp_path), "model.pkg.m", "", "model") + assert "model__pkg__m_compiled_sql" not in ti2.store + + def test_nonexistent_path_does_not_push(self, tmp_path): + """When compiled SQL path does not exist, we do not push.""" + ti = _MockTI() + store_compiled_sql_for_model(ti, str(tmp_path), "model.pkg.my_model", "nonexistent.sql", "model") + assert "model__pkg__my_model_compiled_sql" not in ti.store + + def test_path_traversal_outside_project_does_not_push(self, tmp_path): + """node_path with .. that resolves outside compiled root: file does not exist, so we do not push.""" + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + (compiled_dir / "legit.sql").write_text("SELECT 1") + ti = _MockTI() + store_compiled_sql_for_model(ti, str(tmp_path), "model.pkg.m", "../../../etc/passwd", "model") + assert "model__pkg__m_compiled_sql" not in ti.store + + def test_compiled_sql_under_models_pushed(self, tmp_path): + """node_path relative to models/ (e.g. 
staging/foo.sql) is read and pushed.""" + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" / "staging" + compiled_dir.mkdir(parents=True) + (compiled_dir / "stg_orders.sql").write_text("SELECT * FROM staging.orders") + ti = _MockTI() + store_compiled_sql_for_model(ti, str(tmp_path), "model.pkg.stg_orders", "staging/stg_orders.sql", "model") + assert ti.store.get("model__pkg__stg_orders_compiled_sql") == "SELECT * FROM staging.orders" + + def test_compiled_sql_flat_path_pushed(self, tmp_path): + """node_path as single file under models/ (e.g. foo.sql) is read and pushed.""" + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + (compiled_dir / "foo.sql").write_text("SELECT 1") + ti = _MockTI() + store_compiled_sql_for_model(ti, str(tmp_path), "model.pkg.foo", "foo.sql", "model") + assert ti.store.get("model__pkg__foo_compiled_sql") == "SELECT 1" + @patch("cosmos.dbt.runner.is_available", return_value=False) @patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", return_value="done") @@ -660,7 +790,7 @@ def make_sensor(self, **kwargs): task_id="model.my_model", project_dir="/tmp/project", profile_config=None, - deferrable=True, + deferrable=kwargs.pop("deferrable", True), **kwargs, ) @@ -813,6 +943,47 @@ def test_poke_success_from_run_results(self): result = sensor.poke(context) assert result is True + @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") + def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_override_rtif): + """Test that in subprocess mode, poke extracts compiled_sql from per-model XCom key when status is success.""" + sensor = self.make_sensor() + sensor.invocation_mode = "SUBPROCESS" + sensor.model_unique_id = MODEL_UNIQUE_ID + + ti = MagicMock() + ti.try_number = 1 + # First call returns status from per-model XCom key, second call returns compiled_sql from per-model XCom key + ti.xcom_pull.side_effect = ["success", "SELECT * FROM 
orders"] + context = self.make_context(ti) + + assert sensor.compiled_sql == "" # Initially empty + result = sensor.poke(context) + assert result is True + assert sensor.compiled_sql == "SELECT * FROM orders" + mock_override_rtif.assert_called_once_with(context) + + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor.use_event", return_value=True) + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value="running") + @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") + def test_poke_event_mode_extracts_compiled_sql_from_canonical_key( + self, mock_override_rtif, mock_get_producer, mock_use_event + ): + """Test that in event (DBT_RUNNER) mode, poke gets compiled_sql from canonical *_compiled_sql key after status.""" + sensor = self.make_sensor() + sensor.model_unique_id = MODEL_UNIQUE_ID + + ti = MagicMock() + ti.try_number = 1 + # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then get_xcom_val(compiled_sql_key) + ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders"] + context = self.make_context(ti) + + assert sensor.compiled_sql == "" # Initially empty + result = sensor.poke(context) + assert result is True + assert sensor.compiled_sql == "SELECT * FROM orders" + mock_override_rtif.assert_called_once_with(context) + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None) def _fallback_to_non_watcher_run(self, mock_get_producer_task_status): sensor = self.make_sensor() @@ -917,17 +1088,18 @@ def test_get_status_from_events_none(self): result = sensor._get_status_from_events(ti, context) assert result is None - def test_get_status_from_events_sets_compiled_sql(self): + def test_get_status_from_events_does_not_set_compiled_sql_from_event(self): + """compiled_sql is no longer in the event payload; it is read from the canonical XCom key in poke().""" sensor = self.make_sensor() ti = MagicMock() - 
event_payload = {"data": {"run_result": {"status": "success"}}, "compiled_sql": "select 42"} + event_payload = {"data": {"run_result": {"status": "success"}}} encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") ti.xcom_pull.side_effect = [None, encoded_event] context = self.make_context(ti) result = sensor._get_status_from_events(ti, context) assert result == "success" - assert sensor.compiled_sql == "select 42" + assert sensor.compiled_sql == "" # not set from event; poke() will get it from canonical key @patch("cosmos.operators._watcher.base.get_xcom_val") def test_producer_state_failed(self, mock_get_xcom_val): @@ -1013,6 +1185,17 @@ def test_sensor_not_deferred(self, mock_poke): sensor.execute(context=context) mock_poke.assert_called_once() + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor.poke") + def test_deferrable_false_via_constructor_does_not_defer(self, mock_poke): + """operator_args={'deferrable': False} is respected: sensor created with deferrable=False does not defer.""" + mock_poke.return_value = True + sensor = self.make_sensor(deferrable=False) + assert sensor.deferrable is False + context = {"run_id": "run_id", "task_instance": Mock()} + sensor.execute(context=context) + mock_poke.assert_called_once() + # No TaskDeferred raised: sensor ran synchronously and completed + @pytest.mark.parametrize( "mock_event", [ @@ -1029,6 +1212,32 @@ def test_execute_complete(self, mock_event): else: assert sensor.execute_complete(context=Mock(), event=mock_event) is None + @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") + def test_execute_complete_extracts_compiled_sql(self, mock_override_rtif): + """Test that execute_complete extracts compiled_sql from the event and sets it on the sensor.""" + sensor = self.make_sensor() + context = Mock() + + assert sensor.compiled_sql == "" # Initially empty + + event = {"status": "success", "compiled_sql": "SELECT * FROM orders WHERE status = 
'active'"} + sensor.execute_complete(context=context, event=event) + + assert sensor.compiled_sql == "SELECT * FROM orders WHERE status = 'active'" + mock_override_rtif.assert_called_once_with(context) + + def test_execute_complete_handles_missing_compiled_sql(self): + """Test that execute_complete handles events without compiled_sql gracefully.""" + sensor = self.make_sensor() + context = Mock() + + assert sensor.compiled_sql == "" # Initially empty + + event = {"status": "success"} # No compiled_sql in event + sensor.execute_complete(context=context, event=event) + + assert sensor.compiled_sql == "" # Should remain empty + class TestDbtBuildWatcherOperator: def test_dbt_build_watcher_operator_raises_not_implemented_error(self): @@ -1042,6 +1251,82 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self): +class TestWatcherTrigger: + """Tests for WatcherTrigger compiled_sql extraction.""" + + def make_trigger(self, use_event: bool = False): + return WatcherTrigger( + model_unique_id="model.pkg.my_model", + producer_task_id="dbt_producer_watcher", + dag_id="test_dag", + run_id="test_run", + map_index=None, + use_event=use_event, + poke_interval=1.0, + ) + + @pytest.mark.asyncio + async def test_parse_node_status_and_compiled_sql_subprocess_mode(self): + """Test that compiled_sql is extracted from XCom in subprocess mode.""" + trigger = self.make_trigger(use_event=False) + + # Mock get_xcom_val to return status and compiled_sql + async def mock_get_xcom_val(key): + if key == "model__pkg__my_model_status": + return "success" + elif key == "model__pkg__my_model_compiled_sql": + return "SELECT * FROM orders" + return None + + trigger.get_xcom_val = mock_get_xcom_val + + status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + + assert status == "success" + assert compiled_sql == "SELECT * FROM orders" + + 
@pytest.mark.asyncio + async def test_parse_node_status_and_compiled_sql_subprocess_no_compiled_sql(self): + """Test that missing compiled_sql is handled gracefully in subprocess mode.""" + trigger = self.make_trigger(use_event=False) + + # Mock get_xcom_val to return only status + async def mock_get_xcom_val(key): + if key == "model__pkg__my_model_status": + return "success" + return None + + trigger.get_xcom_val = mock_get_xcom_val + + status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + + assert status == "success" + assert compiled_sql is None + + @pytest.mark.asyncio + async def test_parse_node_status_and_compiled_sql_dbt_runner_mode(self): + """Test that in dbt_runner mode status comes from event payload and compiled_sql from canonical key.""" + trigger = self.make_trigger(use_event=True) + + # Event payload (no longer contains compiled_sql; it is stored under canonical key) + event_data = {"data": {"run_result": {"status": "success"}}} + compressed = base64.b64encode(zlib.compress(json.dumps(event_data).encode())).decode() + + async def mock_get_xcom_val(key): + if key == "nodefinished_model__pkg__my_model": + return compressed + if key == "model__pkg__my_model_compiled_sql": + return "SELECT id FROM users" + return None + + trigger.get_xcom_val = mock_get_xcom_val + + status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + + assert status == "success" + assert compiled_sql == "SELECT id FROM users" + + @pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release") @pytest.mark.integration def test_dbt_dag_with_watcher(capsys): """ @@ -1465,7 +1750,7 @@ def test_dbt_task_group_with_watcher_has_correct_templated_dbt_cmd(): assert full_cmd[1] == "run" # dbt run command cmd = " ".join(full_cmd) - assert "--select stg_customers" in cmd + assert "--select fqn:jaffle_shop.staging.stg_customers" in cmd assert "--threads 1" in cmd diff --git 
a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index aece178fb3..a0fddd9e60 100644 --- 
a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -224,6 +224,123 @@ def test_use_event_returns_false(): assert sensor.use_event() is False +class TestCallbacksNormalization: + """Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator.""" + + def test_callbacks_none_adds_watcher_callback(self): + """ + Test that when callbacks is None, WatcherKubernetesCallback is added. + """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=None, + ) + assert op.callbacks == [WatcherKubernetesCallback] + + def test_callbacks_not_provided_adds_watcher_callback(self): + """ + Test that when callbacks is not provided, WatcherKubernetesCallback is added. + """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + assert op.callbacks == [WatcherKubernetesCallback] + + def test_callbacks_list_appends_watcher_callback(self): + """ + Test that when callbacks is a list, WatcherKubernetesCallback is appended. + """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + class CustomCallback: + pass + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=[CustomCallback], + ) + assert op.callbacks == [CustomCallback, WatcherKubernetesCallback] + + def test_callbacks_tuple_appends_watcher_callback(self): + """ + Test that when callbacks is a tuple, WatcherKubernetesCallback is appended. 
+ """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + class CustomCallback: + pass + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=(CustomCallback,), + ) + assert op.callbacks == [CustomCallback, WatcherKubernetesCallback] + + def test_callbacks_single_value_wraps_and_appends_watcher_callback(self): + """ + Test that when callbacks is a single value (not list/tuple), it is wrapped in a list + and WatcherKubernetesCallback is appended. + """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + class CustomCallback: + pass + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=CustomCallback, + ) + assert op.callbacks == [CustomCallback, WatcherKubernetesCallback] + + def test_callbacks_empty_list_adds_watcher_callback(self): + """ + Test that when callbacks is an empty list, WatcherKubernetesCallback is added. + """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=[], + ) + assert op.callbacks == [WatcherKubernetesCallback] + + def test_callbacks_multiple_values_appends_watcher_callback(self): + """ + Test that when callbacks contains multiple values, WatcherKubernetesCallback is appended. 
+ """ + from cosmos.operators.watcher_kubernetes import WatcherKubernetesCallback + + class CustomCallback1: + pass + + class CustomCallback2: + pass + + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + callbacks=[CustomCallback1, CustomCallback2], + ) + assert op.callbacks == [CustomCallback1, CustomCallback2, WatcherKubernetesCallback] + + def test_callbacks_included_in_producer_operator(): """ Test that the WatcherKubernetesCallback is included in the callbacks of the DbtProducerWatcherKubernetesOperator. diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 3c497de1af..0f7a6bc0a3 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -183,6 +183,7 @@ def test_profile_args( "account": f"{mock_account}.{mock_region}", "database": mock_snowflake_conn.extra_dejson.get("database"), "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -212,6 +213,7 @@ def test_profile_args_overrides( "account": f"{mock_account}.{mock_region}", "database": "my_db_override", "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -275,4 +277,5 @@ def test_old_snowflake_format() -> None: "account": conn.extra_dejson.get("account"), "database": conn.extra_dejson.get("database"), "warehouse": conn.extra_dejson.get("warehouse"), + "threads": 4, } diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index d61cae75af..92e614c996 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -136,6 
+136,7 @@ def test_profile_args( "account": f"{mock_account}.{mock_region}", "database": mock_snowflake_conn.extra_dejson.get("database"), "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -165,6 +166,7 @@ def test_profile_args_overrides( "account": f"{mock_account}.{mock_region}", "database": "my_db_override", "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -213,4 +215,5 @@ def test_old_snowflake_format() -> None: "account": conn.extra_dejson.get("account"), "database": conn.extra_dejson.get("database"), "warehouse": conn.extra_dejson.get("warehouse"), + "threads": 4, } diff --git a/tests/profiles/snowflake/test_snowflake_user_pass.py b/tests/profiles/snowflake/test_snowflake_user_pass.py index 651db27c0e..3d439c7d78 100644 --- a/tests/profiles/snowflake/test_snowflake_user_pass.py +++ b/tests/profiles/snowflake/test_snowflake_user_pass.py @@ -126,6 +126,7 @@ def test_profile_args( "account": mock_snowflake_conn.extra_dejson.get("account"), "database": mock_snowflake_conn.extra_dejson.get("database"), "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -151,6 +152,30 @@ def test_profile_args_overrides( "account": mock_snowflake_conn.extra_dejson.get("account"), "database": "my_db_override", "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, + } + + +def test_profile_args_overrides_threads( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the default threads value via profile_args. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"threads": 8}, + ) + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "password": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PASSWORD') }}", + "schema": mock_snowflake_conn.schema, + "account": mock_snowflake_conn.extra_dejson.get("account"), + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 8, } @@ -197,6 +222,7 @@ def test_old_snowflake_format() -> None: "account": conn.extra_dejson.get("account"), "database": conn.extra_dejson.get("database"), "warehouse": conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -230,6 +256,7 @@ def test_appends_region() -> None: "account": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}", "database": conn.extra_dejson.get("database"), "warehouse": conn.extra_dejson.get("warehouse"), + "threads": 4, } diff --git a/tests/profiles/snowflake/test_snowflake_user_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_privatekey.py index f31fe95981..4dccfe3ed1 100644 --- a/tests/profiles/snowflake/test_snowflake_user_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_privatekey.py @@ -184,6 +184,7 @@ def test_profile_args( "account": mock_snowflake_conn.extra_dejson.get("account"), "database": mock_snowflake_conn.extra_dejson.get("database"), "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -209,6 +210,7 @@ def test_profile_args_overrides( "account": mock_snowflake_conn.extra_dejson.get("account"), "database": "my_db_override", "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -269,6 +271,7 @@ def test_old_snowflake_format() -> None: "account": conn.extra_dejson.get("account"), "database": conn.extra_dejson.get("database"), "warehouse": 
conn.extra_dejson.get("warehouse"), + "threads": 4, } @@ -302,4 +305,5 @@ def test_appends_region() -> None: "account": f"{conn.extra_dejson.get('account')}.{conn.extra_dejson.get('region')}", "database": conn.extra_dejson.get("database"), "warehouse": conn.extra_dejson.get("warehouse"), + "threads": 4, } diff --git a/tests/test_converter.py b/tests/test_converter.py index ef73ec0358..25988d6a8b 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1,9 +1,11 @@ +import logging import tempfile from datetime import datetime from pathlib import Path from unittest.mock import MagicMock, patch import pytest +from airflow.exceptions import ParamValidationError from airflow.models import DAG from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig @@ -525,17 +527,23 @@ def test_converter_creates_dag_with_test_with_multiple_parents_and_build(): assert args[1:] == [ "build", "--select", - "combined_model", + "fqn:my_dbt_project.combined_model", "--exclude", "custom_test_combined_model_combined_model_", ] args = tasks["model.my_dbt_project.model_a"].build_cmd({})[0] - assert args[1:] == ["build", "--select", "model_a", "--exclude", "custom_test_combined_model_combined_model_"] + assert args[1:] == [ + "build", + "--select", + "fqn:my_dbt_project.model_a", + "--exclude", + "custom_test_combined_model_combined_model_", + ] - # The test for model_b should not be changed, since it is not a parent of this test + # The command for model_b should not add any exclusions, since it is not a parent of this test args = tasks["model.my_dbt_project.model_b"].build_cmd({})[0] - assert args[1:] == ["build", "--select", "model_b"] + assert args[1:] == ["build", "--select", "fqn:my_dbt_project.model_b"] # We should have a task dedicated to run the test with multiple parents args = tasks["test.my_dbt_project.custom_test_combined_model_combined_model_.c6e4587380"].build_cmd({})[0] @@ -1274,6 +1282,41 @@ def 
test_telemetry_metadata_storage(mock_load_dbt_graph, mock_should_emit): assert "database" in metadata +@patch("cosmos.converter.Param", side_effect=ParamValidationError("Invalid param")) +@patch("cosmos.converter.should_emit", return_value=True) +@patch("cosmos.converter.DbtGraph.load") +def test_telemetry_metadata_storage_handles_param_validation_error( + mock_load_dbt_graph, mock_should_emit, mock_param, caplog +): + """Test that ParamValidationError during telemetry metadata storage is caught and logged as a warning.""" + dag = DAG("test_dag_telemetry_param_error", start_date=datetime(2024, 1, 1)) + + project_config = ProjectConfig(dbt_project_path=SAMPLE_DBT_PROJECT) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profile_mapping=PostgresUserPasswordProfileMapping(conn_id="test", profile_args={}), + ) + execution_config = ExecutionConfig(execution_mode=ExecutionMode.LOCAL) + render_config = RenderConfig() + + with caplog.at_level(logging.WARNING, logger="cosmos.converter"): + # Should NOT raise, even though Param raises ParamValidationError + _ = DbtToAirflowConverter( + dag=dag, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + ) + + # Verify a warning was logged + assert "Failed to store compressed Cosmos telemetry metadata in DAG test_dag_telemetry_param_error params" in caplog.text + + # Verify metadata was NOT stored in dag.params (since the assignment failed) + assert "__cosmos_telemetry_metadata__" not in dag.params + + @patch("cosmos.converter.DbtGraph.load") @patch("cosmos.converter.should_emit", return_value=False) def test_telemetry_metadata_not_stored_when_disabled(mock_should_emit, mock_load_dbt_graph):