Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
851564f
Draft: dbt compile task
pankajkoti Sep 25, 2024
9dc2c9c
Put compiled files under dag_id folder & refactor few snippets
pankajkoti Sep 29, 2024
0ce662e
Add tests & minor refactorings
pankajkoti Sep 29, 2024
1b6f57e
Apply suggestions from code review
pankajkoti Sep 29, 2024
cc48161
Install deps for the newly added example DAG
pankajkoti Sep 29, 2024
1068025
Add docs
pankajkoti Sep 30, 2024
faa706d
Add async run operator
pankajkoti Sep 25, 2024
0e155e4
Fix remote sql path and async args
pankajastro Sep 30, 2024
5f1ecaa
Fix query
pankajastro Sep 30, 2024
1278847
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
b3d6cf3
Use dbt node's filepath to construct remote path to fetch compiled SQ…
pankajkoti Sep 30, 2024
78bc069
Merge branch 'main' into execute-async-task
tatiana Sep 30, 2024
9ca5e85
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
99bf7c0
Fix unittests
tatiana Sep 30, 2024
3aaaf9e
Improve code
tatiana Sep 30, 2024
43158be
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
83b1010
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
bd6657a
Fix issue when using BQ deferrable operator - it requires location
tatiana Oct 1, 2024
1195955
Add limitation in docs
pankajastro Oct 1, 2024
2bdd9bb
Add full_refresh as templated field
pankajastro Oct 1, 2024
4a44603
Add more template fields
pankajastro Oct 1, 2024
c3c51cb
Construct & relay 'dbt dag-task group' identifier to upload & downloa…
pankajkoti Oct 1, 2024
72c6164
Fix model_name retrieval; get from dbt_node_config
pankajkoti Oct 1, 2024
e67098e
Fix unit tests
pankajkoti Oct 1, 2024
3e550bf
Fix subsequent failing unit tests
pankajkoti Oct 1, 2024
0730d0f
Fix type check failures
pankajkoti Oct 1, 2024
745768e
Add back the deleted sources.yml from jaffle_shop as it has dependenc…
pankajkoti Oct 1, 2024
43d62ea
Install dbt bigquery adapter for running simple_dag_async
pankajkoti Oct 1, 2024
9656248
Install dbt bigquery adapter in our CI setup scripts
pankajkoti Oct 1, 2024
a654f49
Update gcp conn in dev/dags/simple_dag_async.py
pankajkoti Oct 1, 2024
e60ace2
Refactor args in DbtRunAirflowAsyncOperator
tatiana Oct 1, 2024
7f055bc
Use GoogleCloudServiceAccountDictProfileMapping in profilemapping
pankajkoti Oct 1, 2024
ad057c8
set should_upload_compiled_sql to True
pankajkoti Oct 1, 2024
a70ca46
Remove async_op_args
tatiana Oct 1, 2024
7c6a1b2
remove install_deps from DAG
pankajkoti Oct 1, 2024
64a31d0
Merge branch 'main' into execute-async-task
tatiana Oct 1, 2024
c1aeff0
Fix test_build_airflow_graph_with_dbt_compile_task by passing needed …
pankajkoti Oct 1, 2024
02f7985
Specify required project id in the GoogleCloudServiceAccountDictProfi…
pankajkoti Oct 2, 2024
af454a9
Pass gcp_conn_id to super class init, otherwise it is lost & uses the…
pankajkoti Oct 2, 2024
9081e6a
Adapt manifest DAG to use & adapt to the newer GCP conn secret that i…
pankajkoti Oct 2, 2024
2dccf84
Release 1.7.0a1
tatiana Oct 2, 2024
7adeb99
Retrigger GH actions
tatiana Oct 2, 2024
7e6de30
temporarily move out simple_dag_async.py
tatiana Oct 2, 2024
16a87ea
Fix CI issue
tatiana Oct 2, 2024
05db6a0
Fix dbt-compile dependency by using Airflow tasks instead of dbt nodes
pankajkoti Oct 2, 2024
8fc4ae2
Apply suggestions from code review
pankajkoti Oct 2, 2024
ea5816b
Apply suggestions from code review
pankajkoti Oct 2, 2024
85f86a4
Add install instruction
pankajastro Oct 3, 2024
402f823
Add min airflow version in limitation
pankajastro Oct 3, 2024
621a4de
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
a0cb147
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,3 @@ webserver_config.py

# VI
*.sw[a-z]

# Ignore possibly created symlink to `dev/dags` for running `airflow dags test` command.
dags
3 changes: 2 additions & 1 deletion cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

Contains dags, task groups, and operators.
"""
__version__ = "1.6.0"

__version__ = "1.7.0a1"


from cosmos.airflow.dag import DbtDag
Expand Down
33 changes: 26 additions & 7 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def create_task_metadata(
node: DbtNode,
execution_mode: ExecutionMode,
args: dict[str, Any],
dbt_dag_task_group_identifier: str,
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
) -> TaskMetadata | None:
Expand All @@ -142,6 +143,7 @@ def create_task_metadata(
:param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
Default is ExecutionMode.LOCAL.
:param args: Arguments to be used to instantiate an Airflow Task
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: It determines whether to use the name as a prefix for the task id or not.
If it is False, then use the name as a prefix for the task id, otherwise do not.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
Expand All @@ -156,7 +158,10 @@ def create_task_metadata(
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {"dbt_node_config": node.context_dict}
extra_context = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
Expand Down Expand Up @@ -226,6 +231,7 @@ def generate_task_or_group(
node=node,
execution_mode=execution_mode,
args=task_args,
dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group),
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
)
Expand Down Expand Up @@ -268,14 +274,28 @@ def _add_dbt_compile_task(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)

for task_id, task in tasks_map.items():
if not task.upstream_list:
compile_airflow_task >> task

tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]

def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
    """
    Build the identifier used to refer to the DbtDag or DbtTaskGroup being rendered.

    The DAG id and, when a task group is given, its group id are joined with a double underscore.
    Parts that are empty or ``None`` are left out.

    :param dag: The DAG the tasks are being added to.
    :param task_group: The task group the tasks are being added to, if any.
    :returns: ``"<dag_id>"``, ``"<dag_id>__<group_id>"`` or, if the DAG id is empty, just the group id.
    """
    candidate_parts = (dag.dag_id, task_group.group_id if task_group else None)
    return "__".join(part for part in candidate_parts if part)


def build_airflow_graph(
Expand Down Expand Up @@ -358,9 +378,8 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)
_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)


def create_airflow_task_dependencies(
Expand Down
16 changes: 16 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
Expand Down Expand Up @@ -286,6 +287,21 @@ def validate_profiles_yml(self) -> None:
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self) -> str:
    """
    Return the dbt profile type (for example ``bigquery``) configured for this profile.

    When ``profile_mapping`` is a ``BaseProfileMapping``, its ``dbt_profile_type`` is returned.
    Otherwise the profiles YAML file is parsed and the value stored under
    ``<profile_name> -> outputs -> <target_name> -> type`` is returned.

    :returns: The profile type, converted to ``str``.
    :raises KeyError: If the profile name, the target name or the ``type`` entry is missing from the
        parsed profiles file (assuming it parses to nested mappings).
    """
    if isinstance(self.profile_mapping, BaseProfileMapping):
        return str(self.profile_mapping.dbt_profile_type)

    profile_path = self._get_profile_path()

    with open(profile_path) as file:
        profiles = yaml.safe_load(file)

    # The file is only needed for parsing, so the lookups happen once it is closed. Every path through
    # this method now either returns or raises; the former trailing fallback return was unreachable.
    target_type = profiles[self.profile_name]["outputs"][self.target_name]["type"]
    return str(target_type)

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
Expand Down
3 changes: 2 additions & 1 deletion cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
from typing import Any

from airflow.models import BaseOperator
from airflow.models.dag import DAG
Expand All @@ -27,7 +28,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
module = importlib.import_module(module_name)
Operator = getattr(module, class_name)

task_kwargs = {}
task_kwargs: dict[str, Any] = {}
if task.owner != "":
task_kwargs["owner"] = task.owner

Expand Down
1 change: 0 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
Expand Down
179 changes: 151 additions & 28 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,190 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]

class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass
from abc import ABCMeta


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass
from airflow.models.baseoperator import BaseOperator


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass
class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
    """
    Common base for the async-mode operators that are combined with a local dbt operator.

    ``location`` (required) and ``configuration`` (optional, defaults to an empty dict) are removed from
    the keyword arguments and stored on the instance. All remaining keyword arguments are forwarded to
    the next initializer in the method resolution order.
    """

    def __init__(self, **kwargs) -> None:  # type: ignore
        # ``location`` has no default: a missing key raises KeyError before anything is stored.
        location = kwargs.pop("location")
        configuration = kwargs.pop("configuration", {})
        self.location = location
        self.configuration = configuration
        super().__init__(**kwargs)


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
class DbtBuildAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtBuildLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtBuildLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtLSLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtSeedLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtSnapshotLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtSourceLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):  # type: ignore
    """
    Run a dbt model on BigQuery by submitting its compiled SQL through ``BigQueryInsertJobOperator``.

    The operator is always created with ``deferrable=True``. The compiled SQL is read from the remote target
    path configured in the Cosmos settings, and the model's table is dropped and re-created from that SQL.

    :param project_dir: Path to the dbt project. Its parent directory is stripped from the model file path to
        build the location of the compiled SQL under the remote target path.
    :param profile_config: Profile configuration. It must define a profile mapping, and its profile type must be
        one of ``_SUPPORTED_DATABASES``.
    :param location: Job location; mandatory when using ``BigQueryInsertJobOperator`` with ``deferrable=True``.
    :param full_refresh: Must be truthy when the task executes; other modes are rejected by ``execute``.
    :param extra_context: Context attached by Cosmos. ``dbt_node_config`` (with ``file_path`` and
        ``resource_name``) and ``dbt_dag_task_group_identifier`` are read from it at execution time.
    :param configuration: Initial job configuration. It is replaced by ``drop_table_sql`` and ``execute``.
    :param kwargs: Remaining task arguments. Those accepted by ``AbstractDbtBaseOperator`` or
        ``DbtLocalBaseOperator`` (other than ``task_id``) are discarded; the rest go to the parent operator.
    :raises CosmosValueError: If no profile mapping is configured or the profile type is not supported.
    """

    template_fields: Sequence[str] = (
        "full_refresh",
        "project_dir",
        "gcp_project",
        "dataset",
        "location",
    )

    def __init__(  # type: ignore
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        location: str,  # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
        full_refresh: bool = False,
        extra_context: dict[str, object] | None = None,
        configuration: dict[str, object] | None = None,
        **kwargs,
    ) -> None:
        # dbt task param
        self.project_dir = project_dir
        self.extra_context = extra_context or {}
        self.full_refresh = full_refresh
        self.profile_config = profile_config
        if not self.profile_config or not self.profile_config.profile_mapping:
            raise CosmosValueError("Cosmos async support is only available when using ProfileMapping")

        self.profile_type: str = profile_config.get_profile_type()  # type: ignore
        if self.profile_type not in _SUPPORTED_DATABASES:
            raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

        # airflow task param
        self.location = location
        self.configuration = configuration or {}
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]

        # Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.
        # We need to drop them; ``task_id`` is the one dbt-operator argument that must survive.
        non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
        non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())
        non_async_args -= {"task_id"}
        clean_kwargs = {arg_key: arg_value for arg_key, arg_value in kwargs.items() if arg_key not in non_async_args}

        # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            location=self.location,
            deferrable=True,
            **clean_kwargs,
        )

    @staticmethod
    def _relative_to_project_parent(file_path: str, project_dir: str) -> str:
        """
        Return ``file_path`` relative to the parent directory of ``project_dir``, without a leading slash.

        Only a leading path prefix is removed. A plain ``str.replace`` would strip every occurrence of the
        parent string, which corrupts the path when the parent is ``"."`` (all dots, including the file
        extension) or ``"/"`` (all separators). If ``file_path`` is not located under the parent directory it
        is returned unchanged apart from the leading slash.
        """
        try:
            relative_path = Path(file_path).relative_to(Path(project_dir).parent)
        except ValueError:
            return file_path.lstrip("/")
        return str(relative_path).lstrip("/")

    def get_remote_sql(self) -> str:
        """
        Read the compiled SQL of this task's dbt node from the remote target path.

        The remote location is
        ``<remote_target_path>/<dbt_dag_task_group_identifier>/compiled/<file path relative to the project parent>``.

        :returns: The content of the remote compiled SQL file.
        :raises CosmosValueError: If ``settings.AIRFLOW_IO_AVAILABLE`` is false.
        """
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError("Cosmos async support is only available starting in Airflow 2.8 or later.")
        # Imported lazily because the module is only present when AIRFLOW_IO_AVAILABLE is true.
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:
            assert self.project_dir is not None

        relative_file_path = self._relative_to_project_parent(str(file_path), self.project_dir)
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore

    def drop_table_sql(self) -> None:
        """
        Submit a ``DROP TABLE IF EXISTS`` job for the model's table through a ``BigQueryHook``.

        As a side effect ``self.configuration`` is replaced with the drop-table job configuration.
        """
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:
        """
        Drop the model's table, then create it again from the remote compiled SQL.

        :param context: Airflow task context, passed through to the parent operator.
        :returns: Whatever the parent operator's ``execute`` returns.
        :raises CosmosValueError: If ``full_refresh`` is falsy.
        """
        if not self.full_refresh:
            raise CosmosValueError("The async execution only supported for full_refresh")

        # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
        # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
        # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
        # We're emulating this behaviour here
        self.drop_table_sql()
        sql = self.get_remote_sql()
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        # prefix explicit create command to create table
        sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtTestLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtRunOperationLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
class DbtCompileAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCompileLocalOperator):  # type: ignore
    """
    Async-mode counterpart of ``DbtCompileLocalOperator``.

    Adds no behaviour of its own: ``DbtBaseAirflowAsyncOperator`` takes ``location`` and ``configuration``
    out of the keyword arguments and forwards the rest up the initializer chain.
    """

    pass
Loading