Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DBT_SETUP_ASYNC_TASK_ID,
DBT_TEARDOWN_ASYNC_TASK_ID,
DEFAULT_DBT_RESOURCES,
PRODUCER_WATCHER_TASK_ID,
SUPPORTED_BUILD_RESOURCES,
TESTABLE_DBT_RESOURCES,
DbtResourceType,
Expand Down Expand Up @@ -211,6 +212,7 @@ def _get_task_id_and_args(
node: DbtNode,
args: dict[str, Any],
use_task_group: bool,
execution_mode: ExecutionMode,
normalize_task_id: Callable[..., Any] | None,
normalize_task_display_name: Callable[..., Any] | None,
resource_suffix: str,
Expand Down Expand Up @@ -338,6 +340,7 @@ def create_task_metadata(
normalize_task_display_name=normalize_task_display_name,
resource_suffix=resource_suffix,
include_resource_type=True,
execution_mode=execution_mode,
)
elif node.resource_type == DbtResourceType.SOURCE:
args["select"] = f"source:{node.resource_name}"
Expand All @@ -353,7 +356,13 @@ def create_task_metadata(
if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(node, filtered_nodes):
return None
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source"
node=node,
args=args,
use_task_group=use_task_group,
normalize_task_id=normalize_task_id,
normalize_task_display_name=normalize_task_display_name,
resource_suffix=r"source",
execution_mode=execution_mode,
)
if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL:
# render sources without freshness as empty operators
Expand All @@ -372,6 +381,7 @@ def create_task_metadata(
normalize_task_id=normalize_task_id,
normalize_task_display_name=normalize_task_display_name,
resource_suffix=resource_suffix,
execution_mode=execution_mode,
)

_override_profile_if_needed(args, node.profile_config_to_override)
Expand Down Expand Up @@ -531,6 +541,37 @@ def _add_dbt_setup_async_task(
tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task


def _add_producer_watcher(
dag: DAG,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
Comment thread
tatiana marked this conversation as resolved.
) -> str:

producer_task_args = task_args.copy()

if render_config is not None:
producer_task_args["select"] = render_config.select
producer_task_args["selector"] = render_config.selector
producer_task_args["exclude"] = render_config.exclude

producer_task_metadata = TaskMetadata(
id=PRODUCER_WATCHER_TASK_ID,
operator_class="cosmos.operators.watcher.DbtProducerWatcherOperator",
arguments=producer_task_args,
)
producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group)
for task_id, task in tasks_map.items():
# we want to make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom
if not task.upstream_list:
producer_airflow_task >> task
task.trigger_rule = task_args.get("trigger_rule", "always")

tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task
return producer_airflow_task.task_id


def should_create_detached_nodes(render_config: RenderConfig) -> bool:
"""
Decide if we should calculate / insert detached nodes into the graph.
Expand Down Expand Up @@ -704,6 +745,16 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
logger.debug(f"Conversion of <{node.unique_id}> was successful!")
tasks_map[node_id] = task_or_group

if execution_mode == ExecutionMode.WATCHER:
producer_watcher_task_id = _add_producer_watcher(
dag,
task_args,
tasks_map,
task_group,
render_config=render_config,
)
task_args["producer_watcher_task_id"] = producer_watcher_task_id

# If test_behaviour=="after_all", there will be one test task, run by the end of the DAG
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if test_behavior == TestBehavior.AFTER_ALL:
Expand Down Expand Up @@ -738,6 +789,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
tasks_map[node_id] = test_task

create_airflow_task_dependencies(nodes, tasks_map)

if settings.enable_setup_async_task:
_add_dbt_setup_async_task(
dag,
Expand Down
3 changes: 3 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class ExecutionMode(Enum):
Where the Cosmos tasks should be executed.
"""

WATCHER = "watcher"
LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
Expand Down Expand Up @@ -167,6 +168,8 @@ def _missing_value_(cls, value): # type: ignore
DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async"
DBT_TEARDOWN_ASYNC_TASK_ID = "dbt_teardown_async"

PRODUCER_WATCHER_TASK_ID = "dbt_producer_watcher"

TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}/{execution_modes}"
TELEMETRY_VERSION = "v2"
TELEMETRY_TIMEOUT = 1.0
Expand Down
90 changes: 84 additions & 6 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
import zlib
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Sequence

if TYPE_CHECKING: # pragma: no cover
try:
Expand All @@ -18,9 +18,23 @@
from airflow.sensors.base import BaseSensorOperator
from airflow.exceptions import AirflowException

try:
from airflow.providers.standard.operators.empty import EmptyOperator
except ImportError:
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]

from cosmos.config import ProfileConfig
from cosmos.constants import InvocationMode
from cosmos.operators.local import DbtLocalBaseOperator, DbtRunLocalOperator
from cosmos.constants import PRODUCER_WATCHER_TASK_ID, InvocationMode
from cosmos.operators.base import (
DbtRunMixin,
DbtSeedMixin,
DbtSnapshotMixin,
)
from cosmos.operators.local import (
DbtLocalBaseOperator,
DbtRunLocalOperator,
DbtSourceLocalOperator,
)

try:
from dbt_common.events.base_types import EventMsg
Expand All @@ -30,7 +44,9 @@
logger = logging.getLogger(__name__)


CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10
PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999
WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test()


class DbtProducerWatcherOperator(DbtLocalBaseOperator):
Expand Down Expand Up @@ -62,7 +78,8 @@ class DbtProducerWatcherOperator(DbtLocalBaseOperator):

def __init__(self, *args: Any, **kwargs: Any) -> None:
task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator")
kwargs["priority_weight"] = kwargs.get("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT)
kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT)
kwargs.setdefault("weight_rule", WEIGHT_RULE)
super().__init__(task_id=task_id, *args, **kwargs)

@staticmethod
Expand Down Expand Up @@ -142,12 +159,14 @@ def __init__(
profile_config: ProfileConfig | None = None,
project_dir: str | None = None,
profiles_dir: str | None = None,
producer_task_id: str = "dbt_producer_watcher",
poke_interval: int = 20,
producer_task_id: str = PRODUCER_WATCHER_TASK_ID,
poke_interval: int = 10,
Comment thread
tatiana marked this conversation as resolved.
timeout: int = 60 * 60, # 1 h safety valve
**kwargs: Any,
) -> None:
extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {}
kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT)
kwargs.setdefault("weight_rule", WEIGHT_RULE)
super().__init__(
poke_interval=poke_interval,
timeout=timeout,
Expand Down Expand Up @@ -279,3 +298,62 @@ def poke(self, context: Context) -> bool:
return True
else:
raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'")


# This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task.
# That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience.
class DbtBuildWatcherOperator:
def __init__(self, *args: Any, **kwargs: Any):
raise NotImplementedError(
"`ExecutionMode.WATCHER` does not expose a DbtBuild operator, since the build command is executed by the producer task."
)


class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: ignore[misc]
"""
Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)


class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # type: ignore[misc]
"""
Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields # type: ignore[operator]


class DbtSourceWatcherOperator(DbtSourceLocalOperator):
"""
Executes a dbt source freshness command, synchronously, as ExecutionMode.LOCAL.
"""

template_fields: Sequence[str] = DbtSourceLocalOperator.template_fields


class DbtRunWatcherOperator(DbtConsumerWatcherSensor):
"""
Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)


class DbtTestWatcherOperator(EmptyOperator):
"""
As a starting point, this operator does nothing.
We'll be implementing this operator as part of: https://github.com/astronomer/astronomer-cosmos/issues/1974
"""

def __init__(self, *args: Any, **kwargs: Any):
desired_keys = ("dag", "task_group", "task_id")
new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys}
super().__init__(**new_kwargs) # type: ignore[no-untyped-call]
59 changes: 59 additions & 0 deletions dev/dags/example_watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
An example DAG that uses Cosmos to render a dbt project into an Airflow DAG.
"""

import os
from datetime import datetime, timedelta
from pathlib import Path

# [START cosmos_init_imports]
from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

# [END cosmos_init_imports]
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
DBT_PROJECT_NAME = os.getenv("DBT_PROJECT_NAME", "jaffle_shop")
DBT_PROJECT_PATH = DBT_ROOT_PATH / DBT_PROJECT_NAME


profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)


operator_args = {
"install_deps": True, # install any necessary dependencies before running any dbt command
"execution_timeout": timedelta(seconds=120),
}

# Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only:
if os.getenv("CI"):
operator_args["trigger_rule"] = "all_success"


# [START example_watcher]
example_watcher = DbtDag(
# dbt/cosmos-specific parameters
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.WATCHER,
),
project_config=ProjectConfig(DBT_PROJECT_PATH),
profile_config=profile_config,
operator_args=operator_args,
# normal dag parameters
schedule="@daily",
start_date=datetime(2023, 1, 1),
catchup=False,
dag_id="example_watcher",
default_args={"retries": 0},
)
# [END example_watcher]
2 changes: 1 addition & 1 deletion tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,7 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir
assert hash_args == "d41d8cd98f00b204e9800998ecf8427e"
if sys.platform == "darwin":
# We faced inconsistent hashing versions depending on the version of MacOS/Linux - the following line aims to address these.
assert hash_dir in ("481324dabe926f5cf6352b05e5ebe5d7", "60c08a4730a39d03d89f0f87a8ff3931")
assert hash_dir in ("7f64aab068fb7fcf912765605210bf02", "60c08a4730a39d03d89f0f87a8ff3931")
else:
assert hash_dir == "60c08a4730a39d03d89f0f87a8ff3931"

Expand Down
Loading