diff --git a/cosmos/cache.py b/cosmos/cache.py index ed17725bd8..6278b8cc82 100644 --- a/cosmos/cache.py +++ b/cosmos/cache.py @@ -10,6 +10,7 @@ from collections import defaultdict from datetime import datetime, timedelta, timezone from pathlib import Path +from typing import TYPE_CHECKING import msgpack import yaml @@ -22,6 +23,17 @@ from sqlalchemy.orm import Session from cosmos import settings + +if TYPE_CHECKING: + try: + from airflow.sdk import ObjectStoragePath + except ImportError: + try: + from airflow.io.path import ObjectStoragePath + except ImportError: + pass + except ImportError: + pass from cosmos.constants import ( DBT_MANIFEST_FILE_NAME, DBT_TARGET_DIR_NAME, @@ -48,12 +60,12 @@ VAR_KEY_CACHE_PREFIX = "cosmos_cache__" -def _configure_remote_cache_dir() -> Path | None: +def _configure_remote_cache_dir() -> Path | ObjectStoragePath | None: """Configure the remote cache dir if it is provided.""" if not settings_remote_cache_dir: return None - _configured_cache_dir: Path | None = None + _configured_cache_dir: Path | ObjectStoragePath | None = None cache_dir_str = str(settings_remote_cache_dir) diff --git a/cosmos/config.py b/cosmos/config.py index e1a4a0c353..5f4d540d9c 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -8,12 +8,18 @@ import warnings from dataclasses import InitVar, dataclass, field from pathlib import Path -from typing import Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, Callable, Iterator import yaml from airflow.version import version as airflow_version from cosmos import settings + +if TYPE_CHECKING: + try: + from airflow.io.path import ObjectStoragePath + except ImportError: + pass from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled from cosmos.constants import ( DEFAULT_PROFILES_FILE_NAME, @@ -173,7 +179,7 @@ class ProjectConfig: dbt_project_path: Path | None = None install_dbt_deps: bool = True copy_dbt_packages: bool = settings.default_copy_dbt_packages - manifest_path: Path | None = None + manifest_path: Path | ObjectStoragePath | None = None models_path: Path | None = None seeds_path: Path | None = None snapshots_path: Path | None = None @@ -250,7 +256,7 @@ def validate_project(self) -> None: If the project path is not provided, we have a scenario 2 """ - mandatory_paths = {} + mandatory_paths: dict[str, Path | ObjectStoragePath | None] = {} # We validate the existence of paths added to the `mandatory_paths` map by calling the `exists()` method on each # one. Starting with Cosmos 1.6.0, if the Airflow version is `>= 2.8.0` and a `manifest_path` is provided, we # cast it to an `airflow.io.path.ObjectStoragePath` instance during `ProjectConfig` initialisation, and it @@ -259,10 +265,12 @@ def validate_project(self) -> None: # map works correctly for all paths, thereby validating the project. if self.dbt_project_path: project_yml_path = self.dbt_project_path / "dbt_project.yml" - mandatory_paths = { - "dbt_project.yml": Path(project_yml_path) if project_yml_path else None, - "models directory ": Path(self.models_path) if self.models_path else None, - } + mandatory_paths.update( + { + "dbt_project.yml": Path(project_yml_path) if project_yml_path else None, + "models directory ": Path(self.models_path) if self.models_path else None, + } + ) if self.manifest_path: mandatory_paths["manifest"] = self.manifest_path diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ac17a02692..813da01951 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -18,6 +18,12 @@ from airflow.models import Variable +if TYPE_CHECKING: + try: + from airflow.io.path import ObjectStoragePath + except ImportError: + pass + import cosmos.dbt.runner as dbt_runner from cosmos import cache, settings from cosmos.cache import ( @@ -477,7 +483,7 @@ def save_dbt_ls_cache(self, dbt_ls_output: str) -> None: else: Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True) - def _get_dbt_ls_remote_cache(self, remote_cache_dir: Path) -> dict[str, str]: + def _get_dbt_ls_remote_cache(self, remote_cache_dir: Path | ObjectStoragePath) -> dict[str, str]: """Loads the remote cache for dbt ls.""" cache_dict: dict[str, str] = {} remote_cache_key_path = remote_cache_dir / self.dbt_ls_cache_key / "dbt_ls_cache.json" diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index a68cea8ab8..a8a41a6edc 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -26,6 +26,11 @@ from airflow.sdk.definitions.context import Context except ImportError: from airflow.utils.context import Context # type: ignore[attr-defined] + + try: + from airflow.io.path import ObjectStoragePath + except ImportError: + pass from airflow.version import version as airflow_version from attrs import define from packaging.version import Version @@ -289,7 +294,7 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context) -> None: self.compiled_sql = self.compiled_sql.strip() @staticmethod - def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: + def _configure_remote_target_path() -> tuple[Path | ObjectStoragePath, str] | tuple[None, None]: """Configure the remote target path if it is provided.""" if not remote_target_path: return None, None @@ -325,7 +330,7 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: return _configured_target_path, remote_conn_id def _construct_dest_file_path( - self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, resource_type: str + self, dest_target_dir: Path | ObjectStoragePath, file_path: str, source_compiled_dir: Path, resource_type: str ) -> str: """ Construct the destination path for the compiled SQL files to be uploaded to the remote store.