Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from airflow.utils.task_group import TaskGroup

from cosmos import settings
from cosmos.config import RenderConfig
from cosmos.config import ExecutionConfig, RenderConfig
from cosmos.constants import (
DBT_SETUP_ASYNC_TASK_ID,
DBT_TEARDOWN_ASYNC_TASK_ID,
Expand Down Expand Up @@ -803,6 +803,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
async_py_requirements: list[str] | None = None,
execution_config: ExecutionConfig | None = None,
) -> dict[str, Union[TaskGroup, BaseOperator]]:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
Expand Down Expand Up @@ -912,9 +913,10 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
create_airflow_task_dependencies(nodes, tasks_map)

if execution_mode == ExecutionMode.WATCHER:
setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {}
_add_producer_watcher_and_dependencies(
dag=dag,
task_args=task_args,
task_args={**task_args, **setup_operator_args},
tasks_map=tasks_map,
task_group=task_group,
render_config=render_config,
Expand Down
2 changes: 2 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ class ExecutionConfig:
should be used for execution when execution mode is set to `ExecutionMode.VIRTUALENV`
:param async_py_requirements: A list of Python packages to install when `ExecutionMode.AIRFLOW_ASYNC`(Experimental) is used. This parameter is required only if both `enable_setup_async_task` and `enable_teardown_async_task` are set to `True`.
Example: `["dbt-postgres==1.5.0"]`
param setup_operator_args: A dictionary of producer operator parameters. These will override the values supplied in operator_args for producer operator.
"""

execution_mode: ExecutionMode = ExecutionMode.LOCAL
Expand All @@ -434,6 +435,7 @@ class ExecutionConfig:

project_path: Path | None = field(init=False)
async_py_requirements: list[str] | None = None
setup_operator_args: dict[str, Any] | None = None

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
if self.invocation_mode and self.execution_mode not in (
Expand Down
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def __init__(
on_warning_callback=on_warning_callback,
render_config=render_config,
async_py_requirements=execution_config.async_py_requirements,
execution_config=execution_config,
)

current_time = time.perf_counter()
Expand Down
40 changes: 38 additions & 2 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from airflow.utils.state import DagRunState
from packaging.version import Version

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, TestBehavior
from cosmos._triggers.watcher import WatcherTrigger
from cosmos.config import InvocationMode
from cosmos.constants import ExecutionMode
Expand All @@ -28,7 +28,7 @@
DbtSeedWatcherOperator,
DbtTestWatcherOperator,
)
from cosmos.profiles import PostgresUserPasswordProfileMapping
from cosmos.profiles import PostgresUserPasswordProfileMapping, get_automatic_profile_mapping
from tests.utils import AIRFLOW_VERSION, new_test_dag

DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop"
Expand Down Expand Up @@ -974,3 +974,39 @@ def test_dbt_task_group_with_watcher_has_correct_dbt_cmd():
# Verify the command was built correctly
assert full_cmd[1] == "build" # dbt build command
assert "--full-refresh" in full_cmd


@pytest.mark.integration
def test_sensor_and_producer_different_param_values(mock_bigquery_conn):
profile_mapping = get_automatic_profile_mapping(mock_bigquery_conn.conn_id, {})
_profile_config = ProfileConfig(
profile_name="airflow_db",
target_name="bq",
profile_mapping=profile_mapping,
)
dbt_project_path = Path(__file__).parent.parent.parent / "dev/dags/dbt"

dag = DbtDag(
project_config=ProjectConfig(dbt_project_path=dbt_project_path / "jaffle_shop"),
profile_config=_profile_config,
operator_args={
"install_deps": True,
"full_refresh": True,
"deferrable": False,
"execution_timeout": timedelta(seconds=1),
},
render_config=RenderConfig(test_behavior=TestBehavior.NONE),
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.WATCHER, setup_operator_args={"execution_timeout": timedelta(seconds=2)}
),
schedule="@daily",
start_date=datetime(2025, 1, 1),
catchup=False,
dag_id="test_sensor_args_import",
)

for task in dag.tasks_map.values():
if isinstance(task, DbtProducerWatcherOperator):
assert task.execution_timeout == timedelta(seconds=2)
else:
assert task.execution_timeout == timedelta(seconds=1)