Skip to content

Commit

Permalink
test(test_common): rewrite create_dagrun as logical_date is now nullable
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Feb 12, 2025
1 parent 2b3a6f1 commit a3ba5a1
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def __exit__(self, type, value, traceback):
else:
self._bag_dag_compat(self.dag)

def create_dagrun(self, *, logical_date=None, **kwargs):
def create_dagrun(self, *, logical_date=None, **kwargs) -> DagRun:
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
Expand All @@ -905,25 +905,26 @@ def create_dagrun(self, *, logical_date=None, **kwargs):
**kwargs,
}

if logical_date:
logical_date = timezone.coerce_datetime(logical_date)

run_type = kwargs.get("run_type", DagRunType.MANUAL)
if not isinstance(run_type, DagRunType):
run_type = DagRunType(run_type)

if logical_date is None:
if run_type == DagRunType.MANUAL:
logical_date = self.start_date
else:
logical_date = dag.next_dagrun_info(None).logical_date
logical_date = timezone.coerce_datetime(logical_date)

try:
data_interval = kwargs["data_interval"]
except KeyError:
if run_type == DagRunType.MANUAL:
data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
else:
data_interval = dag.infer_automated_data_interval(logical_date)
kwargs["data_interval"] = data_interval
if run_type == DagRunType.ASSET_TRIGGERED:
data_interval = None
elif run_type == DagRunType.MANUAL:
data_interval = dag.infer_automated_data_interval(logical_date) if logical_date else None
elif run_type in (DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED):
data_interval = (
dag.infer_automated_data_interval(logical_date)
if logical_date
else dag.next_dagrun_info(None).logical_date
)
else:
raise ValueError(f"Individual DagRunType {run_type}")
kwargs["data_interval"] = data_interval

if "run_id" not in kwargs:
if "run_type" not in kwargs:
Expand All @@ -943,6 +944,7 @@ def create_dagrun(self, *, logical_date=None, **kwargs):
)
kwargs["run_type"] = run_type

logical_date = kwargs.get("logical_date", None)
if AIRFLOW_V_3_0_PLUS:
kwargs.setdefault("triggered_by", DagRunTriggeredByType.TEST)
kwargs["logical_date"] = logical_date
Expand All @@ -963,7 +965,7 @@ def create_dagrun(self, *, logical_date=None, **kwargs):
self.session.commit()
return self.dag_run

def create_dagrun_after(self, dagrun, **kwargs):
def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun:
next_info = self.dag.next_dagrun_info(self.dag.get_run_data_interval(dagrun))
if next_info is None:
raise ValueError(f"cannot create run after {dagrun}")
Expand Down

0 comments on commit a3ba5a1

Please sign in to comment.