Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
Changelog
=========

1.4.1 (2024-05-17)
1.5.0a2 (2024-05-23)
--------------------

New Features

* Support for running dbt tasks in AWS EKS in #944 by @VolkerSchiewe
* Support caching at the DbtDag and DbtTaskGroup level in #992 by @tatiana (WIP)
- difference from 1.5.0a1: Include timestamp of the DAG in the cache version

Others

* Drop support for Airflow 2.3 in #994 by @pankajkoti
* Update Astro Runtime image in #988 and #989 by @RNHTTR
* Enable ruff F linting in #985 by @pankajastro
* Move Cosmos Airflow configuration to settings.py in #975 by @pankajastro



1.4.1 (2024-05-17)
------------------

Bug fixes

* Fix manifest testing behavior in #955 by @chris-okorodudu
Expand All @@ -20,7 +38,7 @@ Others


1.4.0 (2024-05-13)
--------------------
------------------

Features

Expand Down
2 changes: 1 addition & 1 deletion cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

Contains dags, task groups, and operators.
"""
# Package version; the previous "1.4.1" assignment was immediately overwritten and has been dropped.
__version__ = "1.5.0a2"


from cosmos.airflow.dag import DbtDag
Expand Down
50 changes: 50 additions & 0 deletions cosmos/airflow/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,73 @@

from __future__ import annotations

# import inspect
import pickle
import time
from typing import Any

from airflow.models.dag import DAG

from cosmos import cache
from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
from cosmos.log import get_logger

# Module-level logger, obtained from cosmos.log.get_logger, used by DbtDag below.
logger = get_logger()


class DbtDag(DAG, DbtToAirflowConverter):
    """
    Render a dbt project as an Airflow DAG.

    When ``cache.should_use_cache()`` is true and the project config has a ``dbt_project_path``,
    the rendered DAG is pickled after being built, and ``__new__`` restores that pickle on later
    instantiations while the stored version matches the one computed by
    ``cache.calculate_current_version``.
    """

    def __new__(cls, *args, **kwargs):  # type: ignore
        """
        Return the cached DAG when a valid cache entry exists; otherwise build a new one.

        The object returned is either unpickled or created with ``DAG.__new__(DAG)``. In the latter
        case it is not an instance of ``cls``, so Python does not call ``__init__`` automatically;
        it is invoked explicitly below.
        """
        # dag_id may be given positionally (first positional argument) or as a keyword
        dag_id = kwargs.get("dag_id", args[0] if args else None)
        project_config = kwargs.get("project_config")

        # When we load a Pickle dump of an instance, __new__ may be invoked without kwargs.
        # In those cases, we should not try the cache again, otherwise we'll have an infinite recursion.
        # should_use_cache() is checked first so that the (slow) version calculation is skipped
        # entirely when caching is disabled.
        if (
            dag_id is not None
            and project_config
            and project_config.dbt_project_path
            and cache.should_use_cache()
        ):
            cache_id = cache.create_cache_identifier_v2(dag_id, None)
            # Keep this call directly inside __new__: calculate_current_version finds the DAG file
            # by looking at a fixed depth of the call stack.
            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
            cache_filepath = cache.is_project_unmodified(cache_id, current_version)
            if cache_filepath:
                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
                # NOTE(review): pickle.load can execute arbitrary code; the cache directory must
                # only be writable by trusted users.
                with open(cache_filepath, "rb") as fp:
                    start_time = time.process_time()
                    dbt_dag = pickle.load(fp)
                elapsed_time = time.process_time() - start_time
                logger.info(
                    f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
                )
                return dbt_dag

        instance = DAG.__new__(DAG)
        cls.__init__(instance, *args, **kwargs)  # type: ignore
        return instance

    # The __init__ is not called when restoring the cached version in __new__
    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Build the DAG from the dbt project and, when caching applies, store it in the cache.

        :param args: Positional arguments forwarded to ``DAG`` and ``DbtToAirflowConverter``
        :param kwargs: Keyword arguments, split between ``DAG`` (``airflow_kwargs``) and
            ``DbtToAirflowConverter`` (``specific_kwargs``)
        """
        start_time = time.process_time()
        # Accept dag_id positionally as well, instead of raising KeyError on kwargs["dag_id"]
        dag_id = kwargs.get("dag_id", args[0] if args else None)
        project_config = kwargs.get("project_config")

        DAG.__init__(self, *args, **airflow_kwargs(**kwargs))
        kwargs["dag"] = self
        DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))

        elapsed_time = time.process_time() - start_time
        logger.info(f"It took {elapsed_time} to create the {self.__class__.__name__} {dag_id} from scratch")

        # Same conditions as the lookup in __new__, so the version requested here was already
        # computed (and memoized) there. A project without dbt_project_path cannot be versioned.
        if (
            dag_id is not None
            and project_config
            and project_config.dbt_project_path
            and cache.should_use_cache()
        ):
            cache_id = cache.create_cache_identifier_v2(dag_id, None)
            cache_filepath = cache.get_cache_filepath(cache_id)
            with open(cache_filepath, "wb") as fp:
                pickle.dump(self, fp)
            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
            cache_version_filepath.write_text(current_version)
            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
49 changes: 49 additions & 0 deletions cosmos/airflow/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,74 @@

from __future__ import annotations

import pickle
import time
from typing import Any

from airflow.utils.task_group import TaskGroup

from cosmos import cache
from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs
from cosmos.log import get_logger

# Module-level logger, obtained from cosmos.log.get_logger, used by DbtTaskGroup below.
logger = get_logger()


class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
    """
    Render a dbt project as an Airflow Task Group.

    When ``cache.should_use_cache()`` is true and the project config has a ``dbt_project_path``,
    the rendered task group is pickled after being built, and ``__new__`` restores that pickle on
    later instantiations while the stored version matches the one computed by
    ``cache.calculate_current_version``.
    """

    def __new__(cls, *args, **kwargs):  # type: ignore
        """
        Return the cached task group when a valid cache entry exists; otherwise build a new one.

        The object returned is either unpickled or created with ``TaskGroup.__new__(TaskGroup)``.
        In the latter case it is not an instance of ``cls``, so Python does not call ``__init__``
        automatically; it is invoked explicitly below.
        """
        # NOTE(review): the cache id only includes the DAG when a "dag_id" kwarg is given;
        # confirm whether it should also be derived from a "dag" kwarg or the DAG context.
        dag_id = kwargs.get("dag_id")
        # Resolve the group id exactly as __init__ binds it (keyword, first positional, or default),
        # so the cache is looked up under the same identifier it is stored with.
        group_id = kwargs.get("group_id", args[0] if args else "dbt_task_group")
        project_config = kwargs.get("project_config")

        # When we load a Pickle dump of an instance, __new__ may be invoked without kwargs.
        # In those cases, we should not try the cache again, otherwise we'll have an infinite recursion.
        # should_use_cache() is checked first so that the (slow) version calculation is skipped
        # entirely when caching is disabled.
        if project_config and project_config.dbt_project_path and cache.should_use_cache():
            cache_id = cache.create_cache_identifier_v2(dag_id, group_id)
            # Keep this call directly inside __new__: calculate_current_version finds the DAG file
            # by looking at a fixed depth of the call stack.
            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
            cache_filepath = cache.is_project_unmodified(cache_id, current_version)
            if cache_filepath:
                logger.info(f"Restoring {cls.__name__} {dag_id} from cache {cache_filepath}")
                # NOTE(review): pickle.load can execute arbitrary code; the cache directory must
                # only be writable by trusted users.
                with open(cache_filepath, "rb") as fp:
                    start_time = time.process_time()
                    dbt_dag = pickle.load(fp)
                elapsed_time = time.process_time() - start_time
                logger.info(
                    f"It took {elapsed_time:.3}s to restore the cached version of the {cls.__name__} {dag_id}"
                )
                return dbt_dag

        instance = TaskGroup.__new__(TaskGroup)
        cls.__init__(instance, *args, **kwargs)  # type: ignore
        return instance

    # The __init__ is not called when restoring the cached version in __new__
    def __init__(
        self,
        group_id: str = "dbt_task_group",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Build the task group from the dbt project and, when caching applies, store it in the cache.

        :param group_id: Id of the task group; also part of the cache identifier
        :param args: Positional arguments forwarded to ``TaskGroup`` and ``DbtToAirflowConverter``
        :param kwargs: Keyword arguments, split between ``TaskGroup`` (``airflow_kwargs``) and
            ``DbtToAirflowConverter`` (``specific_kwargs``)
        """
        start_time = time.process_time()
        kwargs["group_id"] = group_id
        dag_id = kwargs.get("dag_id")
        project_config = kwargs.get("project_config")

        TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs))
        kwargs["task_group"] = self
        DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))

        elapsed_time = time.process_time() - start_time
        logger.info(f"It took {elapsed_time} to create the {self.__class__.__name__} {dag_id} from scratch")

        # Same conditions as the lookup in __new__, so the version requested here was already
        # computed (and memoized) there. A project without dbt_project_path cannot be versioned.
        if project_config and project_config.dbt_project_path and cache.should_use_cache():
            cache_id = cache.create_cache_identifier_v2(dag_id, group_id)
            cache_filepath = cache.get_cache_filepath(cache_id)
            with open(cache_filepath, "wb") as fp:
                pickle.dump(self, fp)
            cache_version_filepath = cache.get_cache_version_filepath(cache_id)
            current_version = cache.calculate_current_version(cache_id, project_config.dbt_project_path)
            cache_version_filepath.write_text(current_version)
            logger.info(f"Stored {self.__class__.__name__} {dag_id} cache {cache_filepath}")
84 changes: 84 additions & 0 deletions cosmos/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import functools
import inspect
import shutil
import time
from pathlib import Path

import msgpack
Expand Down Expand Up @@ -171,3 +174,84 @@ def _copy_partial_parse_to_project(partial_parse_filepath: Path, project_path: P

if source_manifest_filepath.exists():
shutil.copy(str(source_manifest_filepath), str(target_manifest_filepath))


# The following methods are being used to cache DbtDag / DbtTaskGroup


# It was considered to create a cache identifier based on the dbt project path, as opposed
# to where it is used in Airflow. However, we could have concurrency issues if the same
# dbt cached directory was being used by different dbt task groups or DAGs within the same
# node. For this reason, as a starting point, the cache is identified by where it is used.
# This can be reviewed in the future.
def create_cache_identifier_v2(dag_id: str | None, task_group_id: str | None) -> str:
    """
    Given a DAG id and an (optional) task group id, create the identifier for caching.

    :param dag_id: Id of the DAG that contains (or is) the cached object
    :param task_group_id: (optional) Id of the Cosmos DbtTaskGroup being cached
    :return: ``<dag_id>__<task_group_id>`` when a task group id is given (only the task group id
        if ``dag_id`` is None); otherwise ``str(dag_id)``
    """
    # FIXME: To be refactored and merged with _create_cache_identifier
    # Missing support to: task_group.group_id
    if not task_group_id:
        # No task group: the DAG id alone identifies the cache (None becomes "None")
        return str(dag_id)

    return "__".join(part for part in (dag_id, task_group_id) if part is not None)


@functools.lru_cache
def get_cache_filepath(cache_identifier: str) -> Path:
    """
    Return the path of the pickle file used to cache the given identifier.

    The file is named ``<cache_identifier>.pkl`` and sits inside the directory returned by
    ``_obtain_cache_dir_path``. The result is memoized per identifier.
    """
    pickle_filename = f"{cache_identifier}.pkl"
    return _obtain_cache_dir_path(cache_identifier) / pickle_filename


@functools.lru_cache
def get_cache_version_filepath(cache_identifier: str) -> Path:
    """
    Return the path of the file that stores the version of the cached object.

    It is the cache file path (see ``get_cache_filepath``) with ``.version`` appended.
    The result is memoized per identifier.
    """
    pickle_filepath = get_cache_filepath(cache_identifier)
    return Path(f"{pickle_filepath}.version")


@functools.lru_cache
def should_use_cache() -> bool:
    """
    Tell whether the DbtDag / DbtTaskGroup cache should be used.

    Both ``settings.enable_cache`` and ``settings.experimental_cache`` must be truthy.
    The answer is memoized after the first call.
    """
    caching_enabled = settings.enable_cache
    return caching_enabled and settings.experimental_cache


@functools.lru_cache
def calculate_current_version(dag_id: str, project_dir: Path) -> str:
    """
    Compute the version string used to decide whether a cached DbtDag / DbtTaskGroup is still valid.

    The version combines the modification time of the file found two frames up the call stack
    with the sum of the modification times of every path under ``project_dir``.

    Because of the fixed stack depth, this function is expected to be called directly from a
    function that is itself called from the DAG file. The result is memoized per
    ``(dag_id, project_dir)`` for the lifetime of the process.

    :param dag_id: Identifier of the object being cached (used for logging and as memoization key)
    :param project_dir: Path to the dbt project directory
    :return: ``"<dag file mtime> <sum of dbt project mtimes>"``
    """
    start_time = time.process_time()

    # When DAG file was last changed.
    # context=0 avoids reading source lines for every frame, which is what makes inspect.stack() slow.
    caller_dag_frame = inspect.stack(0)[2]
    caller_dag_filepath = Path(caller_dag_frame.filename)
    logger.info(f"The {dag_id} DAG is located in: {caller_dag_filepath}")
    dag_last_modified = caller_dag_filepath.stat().st_mtime
    mid_time = time.process_time() - start_time
    logger.info(f"It took {mid_time:.3}s to calculate the first part of the version")

    # Combined value for when the dbt project directory files were last modified
    # This is fast (e.g. 0.01s for jaffle shop, 0.135s for a 5k models dbt folder)
    dbt_combined_last_modified = sum(path.stat().st_mtime for path in project_dir.glob("**/*"))

    elapsed_time = time.process_time() - start_time
    logger.info(f"It took {elapsed_time:.3}s to calculate the cache version for the {dag_id}")
    return f"{dag_last_modified} {dbt_combined_last_modified}"


def is_project_unmodified(dag_id: str, current_version: str) -> Path | None:
    """
    Return the cache file path if the stored cache matches ``current_version``, otherwise None.

    This check is intentionally not memoized: the cache files are created and replaced while the
    process is running, so a memoized "miss" would stay a miss even after the cache was written.

    :param dag_id: Cache identifier (as created by ``create_cache_identifier_v2``)
    :param current_version: Version computed for the current state of the project
    :return: Path to the cached pickle when both cache files exist and the versions match
    """
    cache_filepath = get_cache_filepath(dag_id)
    cache_version_filepath = get_cache_version_filepath(dag_id)

    # Read directly instead of checking for existence first, so a file removed in between
    # is handled as a cache miss rather than an error.
    try:
        previous_cache_version = cache_version_filepath.read_text()
    except FileNotFoundError:
        return None

    if previous_cache_version == current_version and cache_filepath.exists():
        return cache_filepath
    return None
4 changes: 3 additions & 1 deletion cosmos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# On macOS, users may want to set the envvar `TMPDIR` if they do not want the value of the temp directory to change
DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME)
cache_dir = Path(conf.get("cosmos", "cache_dir", fallback=DEFAULT_CACHE_DIR) or DEFAULT_CACHE_DIR)

# Boolean options are read with getboolean; the earlier `conf.get` assignment of enable_cache
# was immediately overwritten and has been dropped.
enable_cache = conf.getboolean("cosmos", "enable_cache", fallback=True)
experimental_cache = conf.getboolean("cosmos", "experimental_cache", fallback=False)
propagate_logs = conf.getboolean("cosmos", "propagate_logs", fallback=True)
dbt_docs_dir = conf.get("cosmos", "dbt_docs_dir", fallback=None)
dbt_docs_conn_id = conf.get("cosmos", "dbt_docs_conn_id", fallback=None)
Expand Down