Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 34 additions & 53 deletions tests/test_dbtf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from datetime import datetime
from pathlib import Path

import pytest
from airflow.utils.state import DagRunState

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import InvocationMode
from cosmos.constants import ExecutionMode, InvocationMode
from cosmos.profiles import GoogleCloudServiceAccountDictProfileMapping

DBT_FUSION_BINARY = Path.home() / ".local/bin/dbt"
Expand Down Expand Up @@ -34,79 +35,56 @@

# Render settings handed to the DbtDag under test as `render_config`: use the dbt binary at
# DBT_FUSION_BINARY (~/.local/bin/dbt) and invoke it as a subprocess.
render_config = RenderConfig(dbt_executable_path=DBT_FUSION_BINARY, invocation_mode=InvocationMode.SUBPROCESS)

execution_config = ExecutionConfig(dbt_executable_path=DBT_FUSION_BINARY, invocation_mode=InvocationMode.SUBPROCESS)


@pytest.mark.integration
@pytest.mark.dbtfusion
def test_dbt_snowflake_dag_with_dbt_fusion():
"""
Run a DbtDag using dbt Fusion.
Confirm it succeeds and has the expected amount of both:
- dbt resources
- Airflow tasks
And that the tasks are in the expected topological order.
"""
snowflake_dag = DbtDag(
execution_config=execution_config,
project_config=project_config,
profile_config=snowflake_profile_config,
render_config=render_config,
start_date=datetime(2023, 1, 1),
dag_id="snowflake_dbt_fusion_dag",
tags=["profiles"],
)
outcome = snowflake_dag.test()
assert outcome.state == DagRunState.SUCCESS

assert len(snowflake_dag.dbt_graph.filtered_nodes) == 23
# Execution settings for the "dbt_fusion_local_*" parametrized cases: the dbt binary at
# DBT_FUSION_BINARY is invoked as a subprocess and execution_mode is not set here
# (presumably this falls back to Cosmos' default mode, as the "local" name suggests; confirm).
local_execution_config = ExecutionConfig(
dbt_executable_path=DBT_FUSION_BINARY, invocation_mode=InvocationMode.SUBPROCESS
)

assert len(snowflake_dag.task_dict) == 13
tasks_names = [task.task_id for task in snowflake_dag.topological_sort()]
expected_task_names = [
"raw_customers_seed",
"raw_orders_seed",
"raw_payments_seed",
"stg_customers.run",
"stg_customers.test",
"stg_orders.run",
"stg_orders.test",
"stg_payments.run",
"stg_payments.test",
"customers.run",
"customers.test",
"orders.run",
"orders.test",
]
assert tasks_names == expected_task_names
# Execution settings for the "dbt_fusion_watcher_bigquery_dag" parametrized case: same binary and
# invocation mode as local_execution_config, but with execution_mode=ExecutionMode.WATCHER.
# For this mode the test expects an extra "dbt_producer_watcher" task first in topological order.
watcher_execution_config = ExecutionConfig(
execution_mode=ExecutionMode.WATCHER,
dbt_executable_path=DBT_FUSION_BINARY,
invocation_mode=InvocationMode.SUBPROCESS,
)


@pytest.mark.parametrize(
"dag_id,execution_config,profile_config",
[
("dbt_fusion_local_snowflake_dag", local_execution_config, snowflake_profile_config),
("dbt_fusion_local_bigquery_dag", local_execution_config, bigquery_profile_config),
("dbt_fusion_watcher_bigquery_dag", watcher_execution_config, bigquery_profile_config),
],
)
@pytest.mark.integration
@pytest.mark.dbtfusion
def test_dbt_bigquery_dag_with_dbt_fusion():
def test_dbt_fusion(dag_id, execution_config, profile_config):
"""
Run a DbtDag using dbt Fusion.
Confirm it succeeds and has the expected amount of both:
- dbt resources
- Airflow tasks
And that the tasks are in the expected topological order.
"""
bigquery_dag = DbtDag(
if os.getenv("CI"):
operator_args = {"trigger_rule": "all_success"}
else:
operator_args = {}
Comment thread
tatiana marked this conversation as resolved.

dbt_fusion_dag = DbtDag(
execution_config=execution_config,
project_config=project_config,
profile_config=bigquery_profile_config,
profile_config=profile_config,
render_config=render_config,
start_date=datetime(2023, 1, 1),
dag_id="bigquery_dbt_fusion_dag",
dag_id=dag_id,
tags=["profiles"],
operator_args=operator_args,
)
outcome = bigquery_dag.test()
outcome = dbt_fusion_dag.test()
assert outcome.state == DagRunState.SUCCESS
Comment thread
tatiana marked this conversation as resolved.

assert len(bigquery_dag.dbt_graph.filtered_nodes) == 23
assert len(dbt_fusion_dag.dbt_graph.filtered_nodes) == 23

Comment thread
tatiana marked this conversation as resolved.
assert len(bigquery_dag.task_dict) == 13
tasks_names = [task.task_id for task in bigquery_dag.topological_sort()]
tasks_names = [task.task_id for task in dbt_fusion_dag.topological_sort()]
expected_task_names = [
"raw_customers_seed",
"raw_orders_seed",
Expand All @@ -122,4 +100,7 @@ def test_dbt_bigquery_dag_with_dbt_fusion():
"orders.run",
"orders.test",
]
if execution_config.execution_mode == ExecutionMode.WATCHER:
expected_task_names.insert(0, "dbt_producer_watcher")

Comment thread
tatiana marked this conversation as resolved.
assert tasks_names == expected_task_names
Loading