1010from airflow .models import DagBag , DagModel , DagRun , DagTag
1111from airflow .models .dag import DagOwnerAttributes
1212from airflow .models .serialized_dag import SerializedDagModel
13- from airflow .providers .common .compat .sdk import DAG , DagRunState , TaskInstanceState
13+ from airflow .providers .common .compat .sdk import DagRunState , TaskInstanceState
1414from airflow .utils .session import create_session
1515from airflow .utils .types import DagRunType
1616from dbt .contracts .results import RunStatus , TestStatus
2222 DbtSourceFreshnessOperator ,
2323 DbtTestOperator ,
2424)
25- from airflow_dbt_python .utils .version import AIRFLOW_V_3_0_PLUS , AIRFLOW_V_3_1_PLUS
25+ from airflow_dbt_python .utils .version import (
26+ AIRFLOW_V_3_0 ,
27+ AIRFLOW_V_3_0_PLUS ,
28+ AIRFLOW_V_3_1_PLUS ,
29+ )
30+
31+ if AIRFLOW_V_3_0 :
32+ # For some reason Airflow 3.0 cannot use dag.test()
33+ from airflow import DAG
34+ else :
35+ from airflow .providers .common .compat .sdk import DAG
2636
2737DATA_INTERVAL_START = pendulum .datetime (2022 , 1 , 1 , tz = "UTC" )
2838DATA_INTERVAL_END = DATA_INTERVAL_START + dt .timedelta (hours = 1 )
@@ -33,6 +43,7 @@ def sync_dag_to_db(
3343 bundle_name : str = "testing" ,
3444):
3545 """Sync dags into the database."""
46+ from airflow .models .dagbundle import DagBundleModel
3647 from airflow .models .serialized_dag import SerializedDagModel
3748 from airflow .serialization .serialized_objects import (
3849 LazyDeserializedDAG ,
@@ -41,11 +52,8 @@ def sync_dag_to_db(
4152 from airflow .utils .session import create_session
4253
4354 with create_session () as session :
44- if AIRFLOW_V_3_1_PLUS :
45- from airflow .models .dagbundle import DagBundleModel
46-
47- session .merge (DagBundleModel (name = bundle_name ))
48- session .flush ()
55+ session .merge (DagBundleModel (name = bundle_name ))
56+ session .flush ()
4957
5058 def _write_dag (dag : DAG ) -> SerializedDAG :
5159 if not SerializedDagModel .has_dag (dag .dag_id ):
@@ -68,9 +76,26 @@ def _create_dagrun(
6876 start_date : dt .datetime ,
6977 run_type : DagRunType ,
7078) -> DagRun :
71- if AIRFLOW_V_3_0_PLUS :
79+ if AIRFLOW_V_3_1_PLUS :
7280 return parent_dag .test ()
7381
82+ elif AIRFLOW_V_3_0_PLUS :
83+ from airflow .utils .types import DagRunTriggeredByType # type: ignore
84+
85+ return parent_dag .create_dagrun ( # type: ignore
86+ run_id = f"{ parent_dag .dag_id } -{ logical_date .isoformat ()} -RUN" ,
87+ state = state ,
88+ logical_date = logical_date ,
89+ data_interval = data_interval ,
90+ start_date = start_date ,
91+ conf = {},
92+ backfill_id = None ,
93+ creating_job_id = None ,
94+ run_type = run_type ,
95+ run_after = dt .datetime (1970 , 1 , 1 , 0 , 0 , 0 , tzinfo = dt .timezone .utc ),
96+ triggered_by = DagRunTriggeredByType .TIMETABLE ,
97+ )
98+
7499 else :
75100 return parent_dag .create_dagrun ( # type: ignore
76101 state = state ,
@@ -102,9 +127,7 @@ def test_dags_loaded(dagbag):
102127@pytest .fixture
103128def testing_dag_bundle ():
104129 """Create a DAG bundle for tests."""
105- from airflow_dbt_python .utils .version import AIRFLOW_V_3_1_PLUS
106-
107- if AIRFLOW_V_3_1_PLUS :
130+ if AIRFLOW_V_3_0_PLUS :
108131 from airflow .models .dagbundle import DagBundleModel
109132
110133 with create_session () as session :
@@ -131,10 +154,12 @@ def _clear_db():
131154 session .query (DagModel ).delete ()
132155 session .query (SerializedDagModel ).delete ()
133156
134- if AIRFLOW_V_3_1_PLUS :
157+ if AIRFLOW_V_3_0_PLUS :
158+ from airflow .models .dag_version import DagVersion
135159 from airflow .models .dagbundle import DagBundleModel
136160
137161 session .query (DagBundleModel ).delete ()
162+ session .query (DagVersion ).delete ()
138163
139164
140165@pytest .fixture (autouse = True )
@@ -303,7 +328,7 @@ def prepare_dbt_project_dir() -> str:
303328
304329 d = generate_dag ()
305330
306- if AIRFLOW_V_3_0_PLUS :
331+ if AIRFLOW_V_3_1_PLUS :
307332 sync_dag_to_db (d )
308333
309334 return d
@@ -315,8 +340,8 @@ def test_dbt_operators_in_taskflow_dag(
315340 profiles_file ,
316341):
317342 """Assert DAG contains correct dbt operators when running."""
318- if AIRFLOW_V_3_0_PLUS :
319- dag = taskflow_dag
343+ if AIRFLOW_V_3_0 :
344+ dag = DAG . from_sdk_dag ( taskflow_dag ) # type: ignore
320345 else :
321346 dag = taskflow_dag
322347
@@ -445,7 +470,7 @@ def target_connection_dag(
445470
446471 dbt_seed >> dbt_run >> dbt_test
447472
448- if AIRFLOW_V_3_0_PLUS :
473+ if AIRFLOW_V_3_1_PLUS :
449474 sync_dag_to_db (dag )
450475 return dag
451476
@@ -510,6 +535,9 @@ def test_example_basic_dag(
510535 """Test the example basic DAG."""
511536 dag = dagbag .get_dag (dag_id = "example_basic_dbt" )
512537
538+ if AIRFLOW_V_3_0 :
539+ dag = DAG .from_sdk_dag (dag ) # type: ignore
540+
513541 assert dag is not None
514542 assert len (dag .tasks ) == 1
515543
@@ -525,7 +553,7 @@ def test_example_basic_dag(
525553 dbt_run .target = "test"
526554 dbt_run .profile = "default"
527555
528- if AIRFLOW_V_3_0_PLUS :
556+ if AIRFLOW_V_3_1_PLUS :
529557 sync_dag_to_db (dag )
530558
531559 dagrun = _create_dagrun (
@@ -578,9 +606,12 @@ def test_example_dbt_project_in_github_dag(
578606 assert dag is not None
579607 assert len (dag .tasks ) == 3
580608
581- if AIRFLOW_V_3_0_PLUS :
609+ if AIRFLOW_V_3_1_PLUS :
582610 sync_dag_to_db (dag )
583611
612+ if AIRFLOW_V_3_0 :
613+ dag = DAG .from_sdk_dag (dag ) # type: ignore
614+
584615 dagrun = _create_dagrun (
585616 dag ,
586617 state = DagRunState .RUNNING ,
@@ -632,9 +663,12 @@ def test_example_complete_dbt_workflow_dag(
632663 assert dag is not None
633664 assert len (dag .tasks ) == 5
634665
635- if AIRFLOW_V_3_0_PLUS :
666+ if AIRFLOW_V_3_1_PLUS :
636667 sync_dag_to_db (dag )
637668
669+ if AIRFLOW_V_3_0 :
670+ dag = DAG .from_sdk_dag (dag ) # type: ignore
671+
638672 dagrun = _create_dagrun (
639673 dag ,
640674 state = DagRunState .RUNNING ,
0 commit comments