Skip to content

Commit a578b2f

Browse files
committed
fix: Just run Airflow 3.0 tests like before
1 parent 83cfdff commit a578b2f

File tree

2 files changed

+54
-19
lines changed

2 files changed

+54
-19
lines changed

airflow_dbt_python/utils/version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ def _get_base_airflow_version_tuple() -> tuple[int, int, int]:
4040

4141
AIRFLOW_V_3_0_PLUS = _get_base_airflow_version_tuple() >= (3, 0, 0)
4242
AIRFLOW_V_3_1_PLUS = _get_base_airflow_version_tuple() >= (3, 1, 0)
43+
AIRFLOW_V_3_0 = AIRFLOW_V_3_0_PLUS and not AIRFLOW_V_3_1_PLUS

tests/dags/test_dbt_dags.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from airflow.models import DagBag, DagModel, DagRun, DagTag
1111
from airflow.models.dag import DagOwnerAttributes
1212
from airflow.models.serialized_dag import SerializedDagModel
13-
from airflow.providers.common.compat.sdk import DAG, DagRunState, TaskInstanceState
13+
from airflow.providers.common.compat.sdk import DagRunState, TaskInstanceState
1414
from airflow.utils.session import create_session
1515
from airflow.utils.types import DagRunType
1616
from dbt.contracts.results import RunStatus, TestStatus
@@ -22,7 +22,17 @@
2222
DbtSourceFreshnessOperator,
2323
DbtTestOperator,
2424
)
25-
from airflow_dbt_python.utils.version import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
25+
from airflow_dbt_python.utils.version import (
26+
AIRFLOW_V_3_0,
27+
AIRFLOW_V_3_0_PLUS,
28+
AIRFLOW_V_3_1_PLUS,
29+
)
30+
31+
if AIRFLOW_V_3_0:
32+
# For some reason Airflow 3.0 cannot use dag.test()
33+
from airflow import DAG
34+
else:
35+
from airflow.providers.common.compat.sdk import DAG
2636

2737
DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
2838
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)
@@ -33,6 +43,7 @@ def sync_dag_to_db(
3343
bundle_name: str = "testing",
3444
):
3545
"""Sync dags into the database."""
46+
from airflow.models.dagbundle import DagBundleModel
3647
from airflow.models.serialized_dag import SerializedDagModel
3748
from airflow.serialization.serialized_objects import (
3849
LazyDeserializedDAG,
@@ -41,11 +52,8 @@ def sync_dag_to_db(
4152
from airflow.utils.session import create_session
4253

4354
with create_session() as session:
44-
if AIRFLOW_V_3_1_PLUS:
45-
from airflow.models.dagbundle import DagBundleModel
46-
47-
session.merge(DagBundleModel(name=bundle_name))
48-
session.flush()
55+
session.merge(DagBundleModel(name=bundle_name))
56+
session.flush()
4957

5058
def _write_dag(dag: DAG) -> SerializedDAG:
5159
if not SerializedDagModel.has_dag(dag.dag_id):
@@ -68,9 +76,26 @@ def _create_dagrun(
6876
start_date: dt.datetime,
6977
run_type: DagRunType,
7078
) -> DagRun:
71-
if AIRFLOW_V_3_0_PLUS:
79+
if AIRFLOW_V_3_1_PLUS:
7280
return parent_dag.test()
7381

82+
elif AIRFLOW_V_3_0_PLUS:
83+
from airflow.utils.types import DagRunTriggeredByType # type: ignore
84+
85+
return parent_dag.create_dagrun( # type: ignore
86+
run_id=f"{parent_dag.dag_id}-{logical_date.isoformat()}-RUN",
87+
state=state,
88+
logical_date=logical_date,
89+
data_interval=data_interval,
90+
start_date=start_date,
91+
conf={},
92+
backfill_id=None,
93+
creating_job_id=None,
94+
run_type=run_type,
95+
run_after=dt.datetime(1970, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc),
96+
triggered_by=DagRunTriggeredByType.TIMETABLE,
97+
)
98+
7499
else:
75100
return parent_dag.create_dagrun( # type: ignore
76101
state=state,
@@ -102,9 +127,7 @@ def test_dags_loaded(dagbag):
102127
@pytest.fixture
103128
def testing_dag_bundle():
104129
"""Create a DAG bundle for tests."""
105-
from airflow_dbt_python.utils.version import AIRFLOW_V_3_1_PLUS
106-
107-
if AIRFLOW_V_3_1_PLUS:
130+
if AIRFLOW_V_3_0_PLUS:
108131
from airflow.models.dagbundle import DagBundleModel
109132

110133
with create_session() as session:
@@ -131,10 +154,12 @@ def _clear_db():
131154
session.query(DagModel).delete()
132155
session.query(SerializedDagModel).delete()
133156

134-
if AIRFLOW_V_3_1_PLUS:
157+
if AIRFLOW_V_3_0_PLUS:
158+
from airflow.models.dag_version import DagVersion
135159
from airflow.models.dagbundle import DagBundleModel
136160

137161
session.query(DagBundleModel).delete()
162+
session.query(DagVersion).delete()
138163

139164

140165
@pytest.fixture(autouse=True)
@@ -303,7 +328,7 @@ def prepare_dbt_project_dir() -> str:
303328

304329
d = generate_dag()
305330

306-
if AIRFLOW_V_3_0_PLUS:
331+
if AIRFLOW_V_3_1_PLUS:
307332
sync_dag_to_db(d)
308333

309334
return d
@@ -315,8 +340,8 @@ def test_dbt_operators_in_taskflow_dag(
315340
profiles_file,
316341
):
317342
"""Assert DAG contains correct dbt operators when running."""
318-
if AIRFLOW_V_3_0_PLUS:
319-
dag = taskflow_dag
343+
if AIRFLOW_V_3_0:
344+
dag = DAG.from_sdk_dag(taskflow_dag) # type: ignore
320345
else:
321346
dag = taskflow_dag
322347

@@ -445,7 +470,7 @@ def target_connection_dag(
445470

446471
dbt_seed >> dbt_run >> dbt_test
447472

448-
if AIRFLOW_V_3_0_PLUS:
473+
if AIRFLOW_V_3_1_PLUS:
449474
sync_dag_to_db(dag)
450475
return dag
451476

@@ -510,6 +535,9 @@ def test_example_basic_dag(
510535
"""Test the example basic DAG."""
511536
dag = dagbag.get_dag(dag_id="example_basic_dbt")
512537

538+
if AIRFLOW_V_3_0:
539+
dag = DAG.from_sdk_dag(dag) # type: ignore
540+
513541
assert dag is not None
514542
assert len(dag.tasks) == 1
515543

@@ -525,7 +553,7 @@ def test_example_basic_dag(
525553
dbt_run.target = "test"
526554
dbt_run.profile = "default"
527555

528-
if AIRFLOW_V_3_0_PLUS:
556+
if AIRFLOW_V_3_1_PLUS:
529557
sync_dag_to_db(dag)
530558

531559
dagrun = _create_dagrun(
@@ -578,9 +606,12 @@ def test_example_dbt_project_in_github_dag(
578606
assert dag is not None
579607
assert len(dag.tasks) == 3
580608

581-
if AIRFLOW_V_3_0_PLUS:
609+
if AIRFLOW_V_3_1_PLUS:
582610
sync_dag_to_db(dag)
583611

612+
if AIRFLOW_V_3_0:
613+
dag = DAG.from_sdk_dag(dag) # type: ignore
614+
584615
dagrun = _create_dagrun(
585616
dag,
586617
state=DagRunState.RUNNING,
@@ -632,9 +663,12 @@ def test_example_complete_dbt_workflow_dag(
632663
assert dag is not None
633664
assert len(dag.tasks) == 5
634665

635-
if AIRFLOW_V_3_0_PLUS:
666+
if AIRFLOW_V_3_1_PLUS:
636667
sync_dag_to_db(dag)
637668

669+
if AIRFLOW_V_3_0:
670+
dag = DAG.from_sdk_dag(dag) # type: ignore
671+
638672
dagrun = _create_dagrun(
639673
dag,
640674
state=DagRunState.RUNNING,

0 commit comments

Comments
 (0)