Skip to content
Merged
12 changes: 11 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: test

on:
push: # Run on pushes to the default branch
branches: [main]
branches: [main,poc-dbt-compile-task]
Comment thread
pankajkoti marked this conversation as resolved.
pull_request_target: # Also run on pull requests originated from forks
branches: [main]

Expand Down Expand Up @@ -176,6 +176,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -248,6 +250,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -316,6 +320,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -393,6 +399,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -537,6 +545,8 @@ jobs:
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432
AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/"
AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn

- name: Upload coverage to Github
uses: actions/upload-artifact@v4
Expand Down
31 changes: 30 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from cosmos.config import RenderConfig
from cosmos.constants import (
DBT_COMPILE_TASK_ID,
DEFAULT_DBT_RESOURCES,
TESTABLE_DBT_RESOURCES,
DbtResourceType,
Expand Down Expand Up @@ -252,6 +253,31 @@ def generate_task_or_group(
return task_or_group


def _add_dbt_compile_task(
nodes: dict[str, DbtNode],
dag: DAG,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

compile_task_metadata = TaskMetadata(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)
tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]


def build_airflow_graph(
nodes: dict[str, DbtNode],
dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
Expand Down Expand Up @@ -332,11 +358,14 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)


def create_airflow_task_dependencies(
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]]
nodes: dict[str, DbtNode],
tasks_map: dict[str, Union[TaskGroup, BaseOperator]],
) -> None:
"""
Create the Airflow task dependencies between non-test nodes.
Expand Down
3 changes: 3 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class ExecutionMode(Enum):
"""

LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
Expand Down Expand Up @@ -147,3 +148,5 @@ def _missing_value_(cls, value): # type: ignore
# It expects that you have already created those resources through the appropriate commands.
# https://docs.getdbt.com/reference/commands/test
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}

DBT_COMPILE_TASK_ID = "dbt_compile"
67 changes: 67 additions & 0 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
pass
9 changes: 9 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,12 @@ def add_cmd_flags(self) -> list[str]:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags


class DbtCompileMixin:
"""
Mixin for dbt compile command.
"""

base_cmd = ["compile"]
ui_color = "#877c7c"
95 changes: 93 additions & 2 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from urllib.parse import urlparse

import airflow
import jinja2
Expand All @@ -17,6 +18,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version as airflow_version
from attr import define
from packaging.version import Version

Expand All @@ -26,10 +28,11 @@
_get_latest_cached_package_lockfile,
is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.settings import AIRFLOW_IO_AVAILABLE, remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
Expand Down Expand Up @@ -67,6 +70,7 @@
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtCompileMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
Expand Down Expand Up @@ -137,6 +141,7 @@ def __init__(
install_deps: bool = False,
callback: Callable[[str], None] | None = None,
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
**kwargs: Any,
) -> None:
Expand All @@ -146,6 +151,7 @@ def __init__(
self.compiled_sql = ""
self.freshness = ""
self.should_store_compiled_sql = should_store_compiled_sql
self.should_upload_compiled_sql = should_upload_compiled_sql
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
Expand Down Expand Up @@ -271,6 +277,84 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")

@staticmethod
def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
"""Configure the remote target path if it is provided."""
if not remote_target_path:
return None, None

_configured_target_path = None

target_path_str = str(remote_target_path)

remote_conn_id = remote_target_path_conn_id
if not remote_conn_id:
target_path_schema = urlparse(target_path_str).scheme
remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment]
if remote_conn_id is None:
Comment thread
pankajkoti marked this conversation as resolved.
return None, None

if not AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(
f"You're trying to specify remote target path {target_path_str}, but the required "
f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
"Airflow 2.8 or later."
)

from airflow.io.path import ObjectStoragePath

_configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id)

if not _configured_target_path.exists(): # type: ignore[no-untyped-call]
_configured_target_path.mkdir(parents=True, exist_ok=True)

return _configured_target_path, remote_conn_id

def _construct_dest_file_path(
self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, context: Context
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
"""
dest_target_dir_str = str(dest_target_dir).rstrip("/")

task = context["task"]
dag_id = task.dag_id
task_group_id = task.task_group.group_id if task.task_group else None
identifiers_list = []
if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id)
dag_task_group_identifier = "__".join(identifiers_list)

rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
"""
if not self.should_upload_compiled_sql:
return

dest_target_dir, dest_conn_id = self._configure_remote_target_path()
if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
)

from airflow.io.path import ObjectStoragePath

source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context)
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
ObjectStoragePath(file_path).copy(dest_object_storage_path)
self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)

@provide_session
def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
"""
Expand Down Expand Up @@ -416,6 +500,7 @@ def run_command(

self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self.upload_compiled_sql(tmp_project_dir, context)
self.handle_exception(result)
if self.callback:
self.callback(tmp_project_dir)
Expand Down Expand Up @@ -920,3 +1005,9 @@ def __init__(self, **kwargs: str) -> None:
raise DeprecationWarning(
"The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead."
)


class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)
3 changes: 3 additions & 0 deletions cosmos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None)
remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None)

remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

try:
LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
Expand Down
Loading