Skip to content
16 changes: 14 additions & 2 deletions cosmos/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING

import msgpack
import yaml
Expand All @@ -22,6 +23,17 @@
from sqlalchemy.orm import Session

from cosmos import settings

if TYPE_CHECKING:
try:
from airflow.sdk import ObjectStoragePath
except ImportError:
try:
from airflow.io.path import ObjectStoragePath
except ImportError:
pass
except ImportError:
pass
from cosmos.constants import (
DBT_MANIFEST_FILE_NAME,
DBT_TARGET_DIR_NAME,
Expand All @@ -48,12 +60,12 @@
VAR_KEY_CACHE_PREFIX = "cosmos_cache__"


def _configure_remote_cache_dir() -> Path | None:
def _configure_remote_cache_dir() -> Path | ObjectStoragePath | None:
"""Configure the remote cache dir if it is provided."""
if not settings_remote_cache_dir:
return None

_configured_cache_dir: Path | None = None
_configured_cache_dir: Path | ObjectStoragePath | None = None

cache_dir_str = str(settings_remote_cache_dir)

Expand Down
22 changes: 15 additions & 7 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
import warnings
from dataclasses import InitVar, dataclass, field
from pathlib import Path
from typing import Any, Callable, Iterator
from typing import TYPE_CHECKING, Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos import settings

if TYPE_CHECKING:
try:
from airflow.io.path import ObjectStoragePath
Comment thread
corsettigyg marked this conversation as resolved.
except ImportError:
pass
from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
from cosmos.constants import (
DEFAULT_PROFILES_FILE_NAME,
Expand Down Expand Up @@ -173,7 +179,7 @@ class ProjectConfig:
dbt_project_path: Path | None = None
install_dbt_deps: bool = True
copy_dbt_packages: bool = settings.default_copy_dbt_packages
manifest_path: Path | None = None
manifest_path: Path | ObjectStoragePath | None = None
models_path: Path | None = None
seeds_path: Path | None = None
snapshots_path: Path | None = None
Expand Down Expand Up @@ -250,7 +256,7 @@ def validate_project(self) -> None:
If the project path is not provided, we have a scenario 2
"""

mandatory_paths = {}
mandatory_paths: dict[str, Path | ObjectStoragePath | None] = {}
# We validate the existence of paths added to the `mandatory_paths` map by calling the `exists()` method on each
# one. Starting with Cosmos 1.6.0, if the Airflow version is `>= 2.8.0` and a `manifest_path` is provided, we
# cast it to an `airflow.io.path.ObjectStoragePath` instance during `ProjectConfig` initialisation, and it
Expand All @@ -259,10 +265,12 @@ def validate_project(self) -> None:
# map works correctly for all paths, thereby validating the project.
if self.dbt_project_path:
project_yml_path = self.dbt_project_path / "dbt_project.yml"
mandatory_paths = {
"dbt_project.yml": Path(project_yml_path) if project_yml_path else None,
"models directory ": Path(self.models_path) if self.models_path else None,
}
mandatory_paths.update(
{
"dbt_project.yml": Path(project_yml_path) if project_yml_path else None,
"models directory ": Path(self.models_path) if self.models_path else None,
}
)
if self.manifest_path:
mandatory_paths["manifest"] = self.manifest_path

Expand Down
8 changes: 7 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

from airflow.models import Variable

if TYPE_CHECKING:
try:
from airflow.io.path import ObjectStoragePath
except ImportError:
pass

import cosmos.dbt.runner as dbt_runner
from cosmos import cache, settings
from cosmos.cache import (
Expand Down Expand Up @@ -477,7 +483,7 @@ def save_dbt_ls_cache(self, dbt_ls_output: str) -> None:
else:
Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True)

def _get_dbt_ls_remote_cache(self, remote_cache_dir: Path) -> dict[str, str]:
def _get_dbt_ls_remote_cache(self, remote_cache_dir: Path | ObjectStoragePath) -> dict[str, str]:
"""Loads the remote cache for dbt ls."""
cache_dict: dict[str, str] = {}
remote_cache_key_path = remote_cache_dir / self.dbt_ls_cache_key / "dbt_ls_cache.json"
Expand Down
9 changes: 7 additions & 2 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from airflow.sdk.definitions.context import Context
except ImportError:
from airflow.utils.context import Context # type: ignore[attr-defined]

try:
from airflow.io.path import ObjectStoragePath
except ImportError:
pass
from airflow.version import version as airflow_version
from attrs import define
from packaging.version import Version
Expand Down Expand Up @@ -289,7 +294,7 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
self.compiled_sql = self.compiled_sql.strip()

@staticmethod
def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
def _configure_remote_target_path() -> tuple[Path | ObjectStoragePath, str] | tuple[None, None]:
"""Configure the remote target path if it is provided."""
if not remote_target_path:
return None, None
Expand Down Expand Up @@ -325,7 +330,7 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
return _configured_target_path, remote_conn_id

def _construct_dest_file_path(
self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, resource_type: str
self, dest_target_dir: Path | ObjectStoragePath, file_path: str, source_compiled_dir: Path, resource_type: str
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
Expand Down