diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 32359d5c93..8b2b9652db 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -19,7 +19,7 @@ from airflow.utils.task_group import TaskGroup from cosmos import settings -from cosmos.config import RenderConfig +from cosmos.config import ExecutionConfig, RenderConfig from cosmos.constants import ( DBT_SETUP_ASYNC_TASK_ID, DBT_TEARDOWN_ASYNC_TASK_ID, @@ -803,6 +803,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_group: TaskGroup | None = None, on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command async_py_requirements: list[str] | None = None, + execution_config: ExecutionConfig | None = None, ) -> dict[str, Union[TaskGroup, BaseOperator]]: """ Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory). @@ -912,9 +913,10 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro create_airflow_task_dependencies(nodes, tasks_map) if execution_mode == ExecutionMode.WATCHER: + setup_operator_args = getattr(execution_config, "setup_operator_args", None) or {} _add_producer_watcher_and_dependencies( dag=dag, - task_args=task_args, + task_args={**task_args, **setup_operator_args}, tasks_map=tasks_map, task_group=task_group, render_config=render_config, diff --git a/cosmos/config.py b/cosmos/config.py index c34c4c7bd3..f8433e4f14 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -422,6 +422,7 @@ class ExecutionConfig: should be used for execution when execution mode is set to `ExecutionMode.VIRTUALENV` :param async_py_requirements: A list of Python packages to install when `ExecutionMode.AIRFLOW_ASYNC`(Experimental) is used. This parameter is required only if both `enable_setup_async_task` and `enable_teardown_async_task` are set to `True`. Example: `["dbt-postgres==1.5.0"]` + param setup_operator_args: A dictionary of producer operator parameters. These will override the values supplied in operator_args for producer operator. """ execution_mode: ExecutionMode = ExecutionMode.LOCAL @@ -434,6 +435,7 @@ class ExecutionConfig: project_path: Path | None = field(init=False) async_py_requirements: list[str] | None = None + setup_operator_args: dict[str, Any] | None = None def __post_init__(self, dbt_project_path: str | Path | None) -> None: if self.invocation_mode and self.execution_mode not in ( diff --git a/cosmos/converter.py b/cosmos/converter.py index 803ad46f7b..b56c91c668 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -336,6 +336,7 @@ def __init__( on_warning_callback=on_warning_callback, render_config=render_config, async_py_requirements=execution_config.async_py_requirements, + execution_config=execution_config, ) current_time = time.perf_counter() diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index bb25ed3792..1cf1ba0c2e 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -15,7 +15,7 @@ from airflow.utils.state import DagRunState from packaging.version import Version -from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, TestBehavior from cosmos._triggers.watcher import WatcherTrigger from cosmos.config import InvocationMode from cosmos.constants import ExecutionMode @@ -28,7 +28,7 @@ DbtSeedWatcherOperator, DbtTestWatcherOperator, ) -from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.profiles import PostgresUserPasswordProfileMapping, get_automatic_profile_mapping from tests.utils import AIRFLOW_VERSION, new_test_dag DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" @@ -974,3 +974,39 @@ def test_dbt_task_group_with_watcher_has_correct_dbt_cmd(): # Verify the command was built correctly assert full_cmd[1] == "build" # dbt build command assert "--full-refresh" in full_cmd + + +@pytest.mark.integration +def test_sensor_and_producer_different_param_values(mock_bigquery_conn): + profile_mapping = get_automatic_profile_mapping(mock_bigquery_conn.conn_id, {}) + _profile_config = ProfileConfig( + profile_name="airflow_db", + target_name="bq", + profile_mapping=profile_mapping, + ) + dbt_project_path = Path(__file__).parent.parent.parent / "dev/dags/dbt" + + dag = DbtDag( + project_config=ProjectConfig(dbt_project_path=dbt_project_path / "jaffle_shop"), + profile_config=_profile_config, + operator_args={ + "install_deps": True, + "full_refresh": True, + "deferrable": False, + "execution_timeout": timedelta(seconds=1), + }, + render_config=RenderConfig(test_behavior=TestBehavior.NONE), + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, setup_operator_args={"execution_timeout": timedelta(seconds=2)} + ), + schedule="@daily", + start_date=datetime(2025, 1, 1), + catchup=False, + dag_id="test_sensor_args_import", + ) + + for task in dag.tasks_map.values(): + if isinstance(task, DbtProducerWatcherOperator): + assert task.execution_timeout == timedelta(seconds=2) + else: + assert task.execution_timeout == timedelta(seconds=1)