diff --git a/cosmos/constants.py b/cosmos/constants.py index b0a428d720..3b2c583637 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -167,3 +167,5 @@ def _missing_value_(cls, value): # type: ignore TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}" TELEMETRY_VERSION = "v1" TELEMETRY_TIMEOUT = 1.0 + +_AIRFLOW3_VERSION = Version("3.0.0a1") diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 7559552f30..034caa9501 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context -from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.version import version as airflow_version from attr import define from packaging.version import Version @@ -32,7 +31,7 @@ _get_latest_cached_package_lockfile, is_cache_package_lockfile_enabled, ) -from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode +from cosmos.constants import _AIRFLOW3_VERSION, FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode from cosmos.dataset import get_dataset_alias_name from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file from cosmos.exceptions import AirflowCompatibilityError, CosmosDbtRunError, CosmosValueError @@ -226,54 +225,63 @@ def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None: """dbtRunnerResult has an attribute `success` that is False if the command failed.""" return dbt_runner.handle_exception_if_needed(result) - @provide_session - def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: + def store_compiled_sql(self, tmp_project_dir: str, context: Context) -> None: """ Takes the compiled SQL files from the dbt run and stores them in the compiled_sql rendered template. Gets called after every dbt run. """ - if not self.should_store_compiled_sql: - return - compiled_queries = {} - # dbt compiles sql files and stores them in the target directory - for folder_path, _, file_paths in os.walk(os.path.join(tmp_project_dir, "target")): - for file_path in file_paths: - if not file_path.endswith(".sql"): - continue + from airflow.utils.session import NEW_SESSION, provide_session - compiled_sql_path = Path(os.path.join(folder_path, file_path)) - compiled_sql = compiled_sql_path.read_text(encoding="utf-8") + @provide_session + def _store_compiled_sql(session: Session = NEW_SESSION) -> None: - relative_path = str(compiled_sql_path.relative_to(tmp_project_dir)) - compiled_queries[relative_path] = compiled_sql.strip() + if not self.should_store_compiled_sql: + return - for name, query in compiled_queries.items(): - self.compiled_sql += f"-- {name}\n{query}\n\n" + compiled_queries = {} + # dbt compiles sql files and stores them in the target directory + for folder_path, _, file_paths in os.walk(os.path.join(tmp_project_dir, "target")): + for file_path in file_paths: + if not file_path.endswith(".sql"): + continue - self.compiled_sql = self.compiled_sql.strip() + compiled_sql_path = Path(os.path.join(folder_path, file_path)) + compiled_sql = compiled_sql_path.read_text(encoding="utf-8") - # need to refresh the rendered task field record in the db because Airflow only does this - # before executing the task, not after - from airflow.models.renderedtifields import RenderedTaskInstanceFields + relative_path = str(compiled_sql_path.relative_to(tmp_project_dir)) + compiled_queries[relative_path] = compiled_sql.strip() - ti = context["ti"] + for name, query in compiled_queries.items(): + self.compiled_sql += f"-- {name}\n{query}\n\n" - if isinstance(ti, TaskInstance): # verifies ti is a TaskInstance in order to access and use the "task" field - if TYPE_CHECKING: - assert ti.task is not None - ti.task.template_fields = self.template_fields - rtif = RenderedTaskInstanceFields(ti, render_templates=False) + self.compiled_sql = self.compiled_sql.strip() - # delete the old records - session.query(RenderedTaskInstanceFields).filter( - RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined] - RenderedTaskInstanceFields.task_id == self.task_id, - RenderedTaskInstanceFields.run_id == ti.run_id, - ).delete() - session.add(rtif) - else: - self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") + # need to refresh the rendered task field record in the db because Airflow only does this + # before executing the task, not after + from airflow.models.renderedtifields import RenderedTaskInstanceFields + + ti = context["ti"] + + if isinstance( + ti, TaskInstance + ): # verifies ti is a TaskInstance in order to access and use the "task" field + if TYPE_CHECKING: + assert ti.task is not None + ti.task.template_fields = self.template_fields + rtif = RenderedTaskInstanceFields(ti, render_templates=False) + + # delete the old records + session.query(RenderedTaskInstanceFields).filter( + RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined] + RenderedTaskInstanceFields.task_id == self.task_id, + RenderedTaskInstanceFields.run_id == ti.run_id, + ).delete() + session.add(rtif) + else: + self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") + + _store_compiled_sql() @staticmethod def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: @@ -353,28 +361,33 @@ def _delete_sql_files(self, tmp_project_dir: Path, resource_type: str) -> None: dest_object_storage_path.unlink() self.log.debug("Deleted %s to %s", file_path, dest_object_storage_path) - @provide_session - def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: + def store_freshness_json(self, tmp_project_dir: str, context: Context) -> None: """ Takes the compiled sources.json file from the dbt source freshness and stores it in the freshness rendered template. Gets called after every dbt run / source freshness. """ - if not self.should_store_compiled_sql: - return + from airflow.utils.session import NEW_SESSION, provide_session - sources_json_path = Path(os.path.join(tmp_project_dir, "target", "sources.json")) + @provide_session + def _store_freshness_json(session: Session = NEW_SESSION) -> None: + if not self.should_store_compiled_sql: + return - if sources_json_path.exists(): - sources_json_content = sources_json_path.read_text(encoding="utf-8").strip() + sources_json_path = Path(os.path.join(tmp_project_dir, "target", "sources.json")) - sources_data = json.loads(sources_json_content) + if sources_json_path.exists(): + sources_json_content = sources_json_path.read_text(encoding="utf-8").strip() - formatted_sources_json = json.dumps(sources_data, indent=4) + sources_data = json.loads(sources_json_content) - self.freshness = formatted_sources_json + formatted_sources_json = json.dumps(sources_data, indent=4) - else: - self.freshness = "" + self.freshness = formatted_sources_json + + else: + self.freshness = "" + + _store_freshness_json() def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd) @@ -464,7 +477,8 @@ def _handle_datasets(self, context: Context) -> None: outlets = self.get_datasets("outputs") self.log.info("Inlets: %s", inlets) self.log.info("Outlets: %s", outlets) - self.register_dataset(inlets, outlets, context) + if AIRFLOW_VERSION < _AIRFLOW3_VERSION: + self.register_dataset(inlets, outlets, context) def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: if self.cache_dir is None: @@ -474,8 +488,9 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None: - self.store_freshness_json(tmp_project_dir, context) - self.store_compiled_sql(tmp_project_dir, context) + if AIRFLOW_VERSION < _AIRFLOW3_VERSION: + self.store_freshness_json(tmp_project_dir, context) + self.store_compiled_sql(tmp_project_dir, context) if self.should_upload_compiled_sql: self._upload_sql_files(tmp_project_dir, "compiled") if self.callback: @@ -641,6 +656,8 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset] with DatasetAlias: https://github.com/apache/airflow/issues/42495 """ + from airflow.utils.session import create_session + if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias: logger.info("Assigning inlets/outlets without DatasetAlias") with create_session() as session: @@ -742,17 +759,24 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if arg_key in base_operator_args: base_operator_kwargs[arg_key] = arg_value AbstractDbtLocalBase.__init__(self, **abstract_dbt_local_base_kwargs) - if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): - from airflow.datasets import DatasetAlias - - # ignoring the type because older versions of Airflow raise the follow error in mypy - # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") - dag_id = kwargs.get("dag") - task_group_id = kwargs.get("task_group") - base_operator_kwargs["outlets"] = [ - DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) - ] # type: ignore - BaseOperator.__init__(self, **base_operator_kwargs) + if AIRFLOW_VERSION < _AIRFLOW3_VERSION: + if ( + kwargs.get("emit_datasets", True) + and settings.enable_dataset_alias + and AIRFLOW_VERSION >= Version("2.10") + ): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in mypy + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") + dag_id = kwargs.get("dag") + task_group_id = kwargs.get("task_group") + base_operator_kwargs["outlets"] = [ + DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) + ] # type: ignore + if "task_id" in base_operator_kwargs: + base_operator_kwargs.pop("task_id") + BaseOperator.__init__(self, task_id=self.task_id, **base_operator_kwargs) class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): diff --git a/scripts/airflow3/dags/example_operators.py b/scripts/airflow3/dags/example_operators.py new file mode 100644 index 0000000000..41fd0e18a3 --- /dev/null +++ b/scripts/airflow3/dags/example_operators.py @@ -0,0 +1,50 @@ +import os +from datetime import datetime +from pathlib import Path + +from airflow import DAG + +from cosmos.config import ProfileConfig +from cosmos.operators.local import DbtCloneLocalOperator, DbtRunLocalOperator, DbtSeedLocalOperator + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent.parent.parent / "dev/dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJ_DIR = DBT_ROOT_PATH / "jaffle_shop" +DBT_PROFILE_PATH = DBT_PROJ_DIR / "profiles.yml" +DBT_ARTIFACT = DBT_PROJ_DIR / "target" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profiles_yml_filepath=DBT_PROFILE_PATH, +) + +with DAG("example_operators", start_date=datetime(2024, 1, 1), catchup=False) as dag: + seed_operator = DbtSeedLocalOperator( + profile_config=profile_config, + project_dir=DBT_PROJ_DIR, + task_id="seed", + dbt_cmd_flags=["--select", "raw_customers"], + install_deps=True, + append_env=True, + ) + + clone_operator = DbtCloneLocalOperator( + profile_config=profile_config, + project_dir=DBT_PROJ_DIR, + task_id="clone", + dbt_cmd_flags=["--models", "stg_customers", "--state", DBT_ARTIFACT], + install_deps=True, + append_env=True, + ) + + run_operator = DbtRunLocalOperator( + profile_config=profile_config, + project_dir=DBT_PROJ_DIR, + task_id="run", + dbt_cmd_flags=["--models", "stg_customers"], + install_deps=True, + append_env=True, + ) + + seed_operator >> [run_operator, clone_operator] diff --git a/scripts/airflow3/env.sh b/scripts/airflow3/env.sh index 2d8f259a8e..5f1ecf5791 100644 --- a/scripts/airflow3/env.sh +++ b/scripts/airflow3/env.sh @@ -1,3 +1,8 @@ +#!/bin/bash + +set -e +set -x + AIRFLOW_HOME="$PWD/scripts/airflow3" export AIRFLOW_HOME export AIRFLOW__LOGGING__BASE_LOG_FOLDER="$AIRFLOW_HOME/logs" diff --git a/scripts/airflow3/setup.sh b/scripts/airflow3/setup.sh index 93ac095d3a..315d186598 100644 --- a/scripts/airflow3/setup.sh +++ b/scripts/airflow3/setup.sh @@ -15,7 +15,15 @@ source "$(pwd)/scripts/airflow3/venv/bin/activate" echo "Installing dependencies..." uv pip install -r "$(pwd)/scripts/airflow3/requirements.txt" -# Optional: Install additional dependencies -uv pip install astronomer-cosmos +# Install Cosmos +uv pip install build +rm -rf "$(pwd)/dist/" +python3 -m build +for wheel in "$(pwd)"/dist/*.whl; do + uv pip install "$wheel" +done + +uv pip install dbt-core +uv pip install dbt-postgres echo "UV virtual environment setup and dependencies installed successfully!" diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index bbafd0fda2..c78478066e 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -1153,7 +1153,7 @@ def test_store_freshness_json(mock_path_class, mock_context, mock_session): expected_freshness = json.dumps({"key": "value"}, indent=4) # Call the method under test - instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session) + instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context) # Verify the freshness attribute is set correctly assert instance.freshness == expected_freshness @@ -1174,7 +1174,7 @@ def test_store_freshness_json_no_file(mock_path_class, mock_context, mock_sessio mock_sources_json_path.exists.return_value = False # Call the method under test - instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session) + instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context) # Verify the freshness attribute is set correctly assert instance.freshness == "" @@ -1189,7 +1189,7 @@ def test_store_freshness_not_store_compiled_sql(mock_context, mock_session): ) # Call the method under test - instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session) + instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context) # Verify the freshness attribute is set correctly assert instance.freshness == ""