Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,5 @@ def _missing_value_(cls, value): # type: ignore
TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}"
TELEMETRY_VERSION = "v1"
TELEMETRY_TIMEOUT = 1.0

_AIRFLOW3_VERSION = Version("3.0.0a1")
150 changes: 87 additions & 63 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version as airflow_version
from attr import define
from packaging.version import Version
Expand All @@ -32,7 +31,7 @@
_get_latest_cached_package_lockfile,
is_cache_package_lockfile_enabled,
)
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.constants import _AIRFLOW3_VERSION, FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError, CosmosDbtRunError, CosmosValueError
Expand Down Expand Up @@ -226,54 +225,63 @@ def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None:
"""dbtRunnerResult has an attribute `success` that is False if the command failed."""
return dbt_runner.handle_exception_if_needed(result)

@provide_session
def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
def store_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Takes the compiled SQL files from the dbt run and stores them in the compiled_sql rendered template.
Gets called after every dbt run.
"""
if not self.should_store_compiled_sql:
return

compiled_queries = {}
# dbt compiles sql files and stores them in the target directory
for folder_path, _, file_paths in os.walk(os.path.join(tmp_project_dir, "target")):
for file_path in file_paths:
if not file_path.endswith(".sql"):
continue
from airflow.utils.session import NEW_SESSION, provide_session

compiled_sql_path = Path(os.path.join(folder_path, file_path))
compiled_sql = compiled_sql_path.read_text(encoding="utf-8")
@provide_session
def _store_compiled_sql(session: Session = NEW_SESSION) -> None:

relative_path = str(compiled_sql_path.relative_to(tmp_project_dir))
compiled_queries[relative_path] = compiled_sql.strip()
if not self.should_store_compiled_sql:
return

for name, query in compiled_queries.items():
self.compiled_sql += f"-- {name}\n{query}\n\n"
compiled_queries = {}
# dbt compiles sql files and stores them in the target directory
for folder_path, _, file_paths in os.walk(os.path.join(tmp_project_dir, "target")):
for file_path in file_paths:
if not file_path.endswith(".sql"):
continue

self.compiled_sql = self.compiled_sql.strip()
compiled_sql_path = Path(os.path.join(folder_path, file_path))
compiled_sql = compiled_sql_path.read_text(encoding="utf-8")

# need to refresh the rendered task field record in the db because Airflow only does this
# before executing the task, not after
from airflow.models.renderedtifields import RenderedTaskInstanceFields
relative_path = str(compiled_sql_path.relative_to(tmp_project_dir))
compiled_queries[relative_path] = compiled_sql.strip()

ti = context["ti"]
for name, query in compiled_queries.items():
self.compiled_sql += f"-- {name}\n{query}\n\n"

if isinstance(ti, TaskInstance): # verifies ti is a TaskInstance in order to access and use the "task" field
if TYPE_CHECKING:
assert ti.task is not None
ti.task.template_fields = self.template_fields
rtif = RenderedTaskInstanceFields(ti, render_templates=False)
self.compiled_sql = self.compiled_sql.strip()

# delete the old records
session.query(RenderedTaskInstanceFields).filter(
RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined]
RenderedTaskInstanceFields.task_id == self.task_id,
RenderedTaskInstanceFields.run_id == ti.run_id,
).delete()
session.add(rtif)
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")
# need to refresh the rendered task field record in the db because Airflow only does this
# before executing the task, not after
from airflow.models.renderedtifields import RenderedTaskInstanceFields

ti = context["ti"]

if isinstance(
ti, TaskInstance
): # verifies ti is a TaskInstance in order to access and use the "task" field
if TYPE_CHECKING:
assert ti.task is not None
ti.task.template_fields = self.template_fields
rtif = RenderedTaskInstanceFields(ti, render_templates=False)

# delete the old records
session.query(RenderedTaskInstanceFields).filter(
RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined]
RenderedTaskInstanceFields.task_id == self.task_id,
RenderedTaskInstanceFields.run_id == ti.run_id,
).delete()
session.add(rtif)
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")

_store_compiled_sql()

@staticmethod
def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
Expand Down Expand Up @@ -353,28 +361,33 @@ def _delete_sql_files(self, tmp_project_dir: Path, resource_type: str) -> None:
dest_object_storage_path.unlink()
self.log.debug("Deleted %s to %s", file_path, dest_object_storage_path)

@provide_session
def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
def store_freshness_json(self, tmp_project_dir: str, context: Context) -> None:
"""
Takes the compiled sources.json file from the dbt source freshness and stores it in the freshness rendered template.
Gets called after every dbt run / source freshness.
"""
if not self.should_store_compiled_sql:
return
from airflow.utils.session import NEW_SESSION, provide_session

sources_json_path = Path(os.path.join(tmp_project_dir, "target", "sources.json"))
@provide_session
def _store_freshness_json(session: Session = NEW_SESSION) -> None:
if not self.should_store_compiled_sql:
return

if sources_json_path.exists():
sources_json_content = sources_json_path.read_text(encoding="utf-8").strip()
sources_json_path = Path(os.path.join(tmp_project_dir, "target", "sources.json"))

sources_data = json.loads(sources_json_content)
if sources_json_path.exists():
sources_json_content = sources_json_path.read_text(encoding="utf-8").strip()

formatted_sources_json = json.dumps(sources_data, indent=4)
sources_data = json.loads(sources_json_content)

self.freshness = formatted_sources_json
formatted_sources_json = json.dumps(sources_data, indent=4)

else:
self.freshness = ""
self.freshness = formatted_sources_json

else:
self.freshness = ""

_store_freshness_json()

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd)
Expand Down Expand Up @@ -464,7 +477,8 @@ def _handle_datasets(self, context: Context) -> None:
outlets = self.get_datasets("outputs")
self.log.info("Inlets: %s", inlets)
self.log.info("Outlets: %s", outlets)
self.register_dataset(inlets, outlets, context)
if AIRFLOW_VERSION < _AIRFLOW3_VERSION:
Copy link
Copy Markdown
Contributor

@pankajkoti pankajkoti Mar 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it maybe better to introduce these conditions in the respective method instead of adding the conditional here and in those methods based on the Airflow version, for now we could decide on whether to call the inner function or not. Later on when we add support for AF3, we could call Airflow version specific inner functions within those functions to handle both versions. WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it, but I found this approach easier for handling decorator calls like provide_session. However, I'll give it some more thought and, if possible, will create a follow-up PR.

self.register_dataset(inlets, outlets, context)

def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
if self.cache_dir is None:
Expand All @@ -474,8 +488,9 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
cache._update_partial_parse_cache(partial_parse_file, self.cache_dir)

def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None:
self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
if AIRFLOW_VERSION < _AIRFLOW3_VERSION:
self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
if self.should_upload_compiled_sql:
self._upload_sql_files(tmp_project_dir, "compiled")
if self.callback:
Expand Down Expand Up @@ -641,6 +656,8 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset]
with DatasetAlias:
https://github.com/apache/airflow/issues/42495
"""
from airflow.utils.session import create_session

if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias:
logger.info("Assigning inlets/outlets without DatasetAlias")
with create_session() as session:
Expand Down Expand Up @@ -742,17 +759,24 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if arg_key in base_operator_args:
base_operator_kwargs[arg_key] = arg_value
AbstractDbtLocalBase.__init__(self, **abstract_dbt_local_base_kwargs)
if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"):
from airflow.datasets import DatasetAlias

# ignoring the type because older versions of Airflow raise the follow error in mypy
# error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
dag_id = kwargs.get("dag")
task_group_id = kwargs.get("task_group")
base_operator_kwargs["outlets"] = [
DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id))
] # type: ignore
BaseOperator.__init__(self, **base_operator_kwargs)
if AIRFLOW_VERSION < _AIRFLOW3_VERSION:
if (
kwargs.get("emit_datasets", True)
and settings.enable_dataset_alias
and AIRFLOW_VERSION >= Version("2.10")
):
from airflow.datasets import DatasetAlias

# ignoring the type because older versions of Airflow raise the follow error in mypy
# error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str")
dag_id = kwargs.get("dag")
task_group_id = kwargs.get("task_group")
base_operator_kwargs["outlets"] = [
DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id))
] # type: ignore
if "task_id" in base_operator_kwargs:
base_operator_kwargs.pop("task_id")
Comment thread
pankajastro marked this conversation as resolved.
BaseOperator.__init__(self, task_id=self.task_id, **base_operator_kwargs)


class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator):
Expand Down
50 changes: 50 additions & 0 deletions scripts/airflow3/dags/example_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
from datetime import datetime
from pathlib import Path

from airflow import DAG

from cosmos.config import ProfileConfig
from cosmos.operators.local import DbtCloneLocalOperator, DbtRunLocalOperator, DbtSeedLocalOperator

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent.parent.parent / "dev/dags/dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
DBT_PROJ_DIR = DBT_ROOT_PATH / "jaffle_shop"
DBT_PROFILE_PATH = DBT_PROJ_DIR / "profiles.yml"
DBT_ARTIFACT = DBT_PROJ_DIR / "target"

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profiles_yml_filepath=DBT_PROFILE_PATH,
)

with DAG("example_operators", start_date=datetime(2024, 1, 1), catchup=False) as dag:
seed_operator = DbtSeedLocalOperator(
profile_config=profile_config,
project_dir=DBT_PROJ_DIR,
task_id="seed",
dbt_cmd_flags=["--select", "raw_customers"],
install_deps=True,
append_env=True,
)

clone_operator = DbtCloneLocalOperator(
profile_config=profile_config,
project_dir=DBT_PROJ_DIR,
task_id="clone",
dbt_cmd_flags=["--models", "stg_customers", "--state", DBT_ARTIFACT],
install_deps=True,
append_env=True,
)

run_operator = DbtRunLocalOperator(
profile_config=profile_config,
project_dir=DBT_PROJ_DIR,
task_id="run",
dbt_cmd_flags=["--models", "stg_customers"],
install_deps=True,
append_env=True,
)

seed_operator >> [run_operator, clone_operator]
5 changes: 5 additions & 0 deletions scripts/airflow3/env.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
#!/bin/bash

set -e
set -x

AIRFLOW_HOME="$PWD/scripts/airflow3"
export AIRFLOW_HOME
export AIRFLOW__LOGGING__BASE_LOG_FOLDER="$AIRFLOW_HOME/logs"
Expand Down
12 changes: 10 additions & 2 deletions scripts/airflow3/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ source "$(pwd)/scripts/airflow3/venv/bin/activate"
echo "Installing dependencies..."
uv pip install -r "$(pwd)/scripts/airflow3/requirements.txt"

# Optional: Install additional dependencies
uv pip install astronomer-cosmos
# Install Cosmos
uv pip install build
rm -rf "$(pwd)/dist/"
python3 -m build
for wheel in "$(pwd)"/dist/*.whl; do
uv pip install "$wheel"
done

uv pip install dbt-core
uv pip install dbt-postgres

echo "UV virtual environment setup and dependencies installed successfully!"
6 changes: 3 additions & 3 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def test_store_freshness_json(mock_path_class, mock_context, mock_session):
expected_freshness = json.dumps({"key": "value"}, indent=4)

# Call the method under test
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session)
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context)

# Verify the freshness attribute is set correctly
assert instance.freshness == expected_freshness
Expand All @@ -1174,7 +1174,7 @@ def test_store_freshness_json_no_file(mock_path_class, mock_context, mock_sessio
mock_sources_json_path.exists.return_value = False

# Call the method under test
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session)
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context)

# Verify the freshness attribute is set correctly
assert instance.freshness == ""
Expand All @@ -1189,7 +1189,7 @@ def test_store_freshness_not_store_compiled_sql(mock_context, mock_session):
)

# Call the method under test
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context, session=mock_session)
instance.store_freshness_json(tmp_project_dir="/mock/dir", context=mock_context)

# Verify the freshness attribute is set correctly
assert instance.freshness == ""
Expand Down