diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index af854d4f50..4286a7a081 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -99,7 +99,10 @@ def create_test_task_metadata( def create_task_metadata( - node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_task_group: bool = False + node: DbtNode, + execution_mode: ExecutionMode, + args: dict[str, Any], + use_task_group: bool = False, ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -156,6 +159,7 @@ def generate_task_or_group( test_behavior: TestBehavior, test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None, + node_config: dict[str, Any] | None = None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None @@ -176,7 +180,7 @@ def generate_task_or_group( if task_meta and node.resource_type != DbtResourceType.TEST: if use_task_group: with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: - task = create_airflow_task(task_meta, dag, task_group=model_task_group) + task = create_airflow_task(task_meta, dag, task_group=model_task_group, extra_context=node_config) test_meta = create_test_task_metadata( "test", execution_mode, @@ -185,11 +189,12 @@ def generate_task_or_group( node=node, on_warning_callback=on_warning_callback, ) - test_task = create_airflow_task(test_meta, dag, task_group=model_task_group) + test_task = create_airflow_task(test_meta, dag, task_group=model_task_group, extra_context=node_config) task >> test_task task_or_group = model_task_group else: task_or_group = create_airflow_task(task_meta, dag, task_group=task_group) + return task_or_group @@ -251,6 +256,7 @@ def build_airflow_graph( test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, node=node, + node_config=node.config, ) if task_or_group is not None: logger.debug(f"Conversion of <{node.unique_id}> was successful!") diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 7c5dee3281..ffb6137a19 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -6,12 +6,14 @@ from cosmos.core.graph.entities import Task from cosmos.log import get_logger - +from typing import Any logger = get_logger(__name__) -def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator: +def get_airflow_task( + task: Task, dag: DAG, task_group: "TaskGroup | None" = None, extra_context: dict[str, Any] | None = None +) -> BaseOperator: """ Get the Airflow Operator class for a Task. @@ -30,6 +32,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None task_id=task.id, dag=dag, task_group=task_group, + extra_context=extra_context, **task.arguments, ) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 6d276013d5..3aa5dae08c 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -57,6 +57,7 @@ class DbtBaseOperator(BaseOperator): (i.e. /home/astro/.pyenv/versions/dbt_venv/bin/dbt) :param dbt_cmd_flags: List of flags to pass to dbt command :param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command + :param extra_context: A dictionary of values to add to the Airflow Task context """ template_fields: Sequence[str] = ("env", "vars") @@ -105,6 +106,7 @@ def __init__( dbt_executable_path: str = get_system_dbt(), dbt_cmd_flags: list[str] | None = None, dbt_cmd_global_flags: list[str] | None = None, + extra_context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.project_dir = project_dir @@ -132,6 +134,7 @@ def __init__( self.dbt_executable_path = dbt_executable_path self.dbt_cmd_flags = dbt_cmd_flags self.dbt_cmd_global_flags = dbt_cmd_global_flags or [] + self.extra_context = extra_context or {} super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: @@ -231,3 +234,9 @@ def build_cmd( env = self.get_env(context) return dbt_cmd, env + + def pre_execute(self, context: Any) -> None: + if self.extra_context: + logger.info("Extra context passed to operator, injecting into TaskInstance") + context["model_config"] = self.extra_context + return super().pre_execute(context) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 6eea764ad1..0e7d3684c5 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -236,9 +236,7 @@ def run_command( ) if is_openlineage_available: self.calculate_openlineage_events_completes(env, Path(tmp_project_dir)) - context[ - "task_instance" - ].openlineage_events_completes = self.openlineage_events_completes # type: ignore + context["task_instance"].openlineage_events_completes = self.openlineage_events_completes # type: ignore if self.emit_datasets: inlets = self.get_datasets("inputs")