From dd32a129545f353d3e3de0685c07d9770922c3e9 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Fri, 30 May 2025 15:42:03 +0530 Subject: [PATCH 1/3] Fix schedule for AF3 --- dagfactory/__init__.py | 2 +- dagfactory/dagbuilder.py | 93 ++++++++++++- dagfactory/utils.py | 8 ++ tests/test_dagbuilder.py | 288 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 384 insertions(+), 7 deletions(-) diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index 18274204..75774c15 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -2,7 +2,7 @@ from .dagfactory import DagFactory, load_yaml_dags -__version__ = "0.23.0a4" +__version__ = "0.23.0a5" __all__ = [ "DagFactory", "load_yaml_dags", diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 00f673cf..4400a4b3 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -11,12 +11,13 @@ import warnings from copy import deepcopy from datetime import datetime, timedelta -from functools import partial +from functools import partial, reduce from typing import Any, Callable, Dict, List, Tuple, Union from airflow import DAG, configuration from airflow.models import BaseOperator, Variable from airflow.utils.module_loading import import_string +from dateutil.relativedelta import relativedelta from packaging import version from dagfactory.constants import AIRFLOW3_MAJOR_VERSION @@ -729,6 +730,90 @@ def process_file_with_datasets(file: str, datasets_conditions: str) -> Any: datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys())) return [Dataset(uri) for uri in datasets_uri] + @staticmethod + def _init_watchers(watchers_data): + """Initialize watcher objects from configuration.""" + from dagfactory.utils import _import_from_string + + watchers = [] + for watcher in watchers_data: + watcher_class = _import_from_string(watcher["callable"]) + trigger_data = watcher.get("trigger", {}) + trigger_class = _import_from_string(trigger_data.get("callable")) + trigger_params = trigger_data.get("params", {}) + watchers.append(watcher_class(name=watcher.get("name"), trigger=trigger_class(**trigger_params))) + return watchers + + @staticmethod + def _init_asset(asset_dict: dict): + from airflow.sdk import Asset + + """Initialize an Asset from its configuration dictionary.""" + if "watchers" in asset_dict: + asset_dict["watchers"] = DagBuilder._init_watchers(asset_dict["watchers"]) + return Asset(**asset_dict) + + @staticmethod + def _combine_assets(assets, op: str): + """Combine a list of Asset objects using logical operators.""" + if op == "or": + return reduce(lambda a, b: a | b, assets) + elif op == "and": + return reduce(lambda a, b: a & b, assets) + else: + raise ValueError(f"Unknown operator: {op}") + + @staticmethod + def _asset_schedule(value): + """Recursively parse and construct assets or combinations of assets.""" + if isinstance(value, dict): + if "or" in value: + assets = [DagBuilder._asset_schedule(item) for item in value["or"]] + return DagBuilder._combine_assets(assets, "or") + elif "and" in value: + assets = [DagBuilder._asset_schedule(item) for item in value["and"]] + return DagBuilder._combine_assets(assets, "and") + elif "uri" in value: + return DagBuilder._init_asset(value) + else: + raise ValueError(f"Invalid asset entry: {value}") + elif isinstance(value, list): + return [DagBuilder._init_asset(asset) for asset in value] + else: + raise TypeError(f"Unexpected data type: {type(value)}") + + @staticmethod + def _resolve_schedule(dag_params): + schedule = dag_params.get("schedule") + if schedule is None: + return None + + # Case 1: If schedule is a string, return it directly + if isinstance(schedule, str): + return schedule + + # Case 2: If schedule is a dictionary + if isinstance(schedule, dict): + schedule_type = schedule.get("type") + value = schedule.get("value") + + dispatch = { + "cron": lambda v: v, + "timetable": lambda v: DagBuilder.make_timetable(v.get("callable"), v.get("params", {})), + "timedelta": lambda v: timedelta(**v), + "relativedelta": lambda v: relativedelta(**v), + "assets": lambda v: DagBuilder._asset_schedule(v), + } + + try: + handler = dispatch[schedule_type] + except KeyError: + raise DagFactoryException(f"Schedule type {schedule_type} is not supported.") + + return handler(value) + + raise DagFactoryException(f"Unexpected value for schedule: {schedule}") + @staticmethod def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None: """ @@ -774,11 +859,7 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - if has_datasets_attr: schedule.pop("datasets") else: - schedule = dag_params.get("schedule") - if schedule.strip().lower() == "none": - dag_kwargs["schedule"] = None - else: - dag_kwargs["schedule"] = schedule + dag_kwargs["schedule"] = DagBuilder._resolve_schedule(dag_params) # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: diff --git a/dagfactory/utils.py b/dagfactory/utils.py index 054a79c1..6f3e5261 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -9,6 +9,7 @@ import sys import types from datetime import date, datetime, timedelta +from importlib import import_module from pathlib import Path from typing import Any, AnyStr, Dict, List, Match, Optional, Pattern, Tuple, Union @@ -18,6 +19,13 @@ from dagfactory.exceptions import DagFactoryException +def _import_from_string(dotted_path: str): + """Import a class or function from a dotted path string.""" + module_path, _, attr = dotted_path.rpartition(".") + module = import_module(module_path) + return getattr(module, attr) + + def get_datetime(date_value: Union[str, datetime, date], timezone: str = "UTC") -> datetime: """ Takes value from DAG config and generates valid datetime. Defaults to diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index e3a2e150..59d74d67 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -23,6 +23,13 @@ one_hour_ago, ) +import yaml +from airflow import DAG +from packaging import version + +from dagfactory.dagbuilder import INSTALLED_AIRFLOW_VERSION, DagBuilder, DagFactoryConfigException, Dataset + + try: from airflow.providers.http.sensors.http import HttpSensor except ImportError: # Airflow < 2.4 @@ -1153,6 +1160,287 @@ def test_make_nested_task_groups(): assert sub_task_group == expected["sub_task_group"].__dict__ +@pytest.mark.skipif(INSTALLED_AIRFLOW_VERSION.major < 3, reason="Requires Airflow >= 3.0.0") +class TestSchedule: + + def test_asset_schedule_list_of_assets(self): + from airflow.sdk import Asset + + yaml_str = """ + - uri: s3://dag1/output_1.txt + extra: + hi: bye + - uri: s3://dag2/output_1.txt + extra: + hi: bye + """ + data = yaml.safe_load(yaml_str) + parsed_schedule = DagBuilder._asset_schedule(data) + + expected = [ + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + Asset( + name="s3://dag2/output_1.txt", + uri="s3://dag2/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ] + assert parsed_schedule == expected + + def test_asset_schedule_with_and_operator(self): + from airflow.sdk import Asset, AssetAll + + yaml_str = """ + and: + - uri: s3://dag1/output_1.txt + extra: + hi: bye + - uri: s3://dag2/output_1.txt + extra: + hi: bye + """ + data = yaml.safe_load(yaml_str) + parsed_schedule = DagBuilder._asset_schedule(data) + + expected = AssetAll( + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + Asset( + name="s3://dag2/output_1.txt", + uri="s3://dag2/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ) + assert parsed_schedule == expected + + def test_asset_schedule_with_or_operator(self): + from airflow.sdk import Asset, AssetAny + + yaml_str = """ + or: + - uri: s3://dag1/output_1.txt + extra: + hi: bye + - uri: s3://dag2/output_1.txt + extra: + hi: bye + """ + data = yaml.safe_load(yaml_str) + parsed_schedule = DagBuilder._asset_schedule(data) + + expected = AssetAny( + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + Asset( + name="s3://dag2/output_1.txt", + uri="s3://dag2/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ) + assert parsed_schedule == expected + + def test_asset_schedule_with_nested_operators(self): + from airflow.sdk import Asset, AssetAll, AssetAny + + yaml_str = """ + or: + - and: + - uri: s3://dag1/output_1.txt + extra: + hi: bye + - uri: s3://dag2/output_1.txt + extra: + hi: bye + - uri: s3://dag3/output_3.txt + extra: + hi: bye + """ + data = yaml.safe_load(yaml_str) + parsed_schedule = DagBuilder._asset_schedule(data) + + expected = AssetAny( + AssetAll( + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + Asset( + name="s3://dag2/output_1.txt", + uri="s3://dag2/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ), + Asset( + name="s3://dag3/output_3.txt", + uri="s3://dag3/output_3.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ) + assert parsed_schedule == expected + + def test_asset_schedule_with_watcher(self): + from airflow.providers.standard.triggers.file import FileDeleteTrigger + from airflow.sdk import Asset, AssetWatcher + + yaml_str = """ + - uri: s3://dag1/output_1.txt + extra: + hi: bye + watchers: + - callable: airflow.sdk.AssetWatcher + name: test_asset_watcher + trigger: + callable: airflow.providers.standard.triggers.file.FileDeleteTrigger + params: + filepath: "/temp/file.txt" + """ + data = yaml.safe_load(yaml_str) + parsed_schedule = DagBuilder._asset_schedule(data) + + expected = [ + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[ + AssetWatcher( + name="test_asset_watcher", + trigger=FileDeleteTrigger(filepath="/temp/file.txt", poke_interval=5.0), + ) + ], + ) + ] + assert parsed_schedule == expected + + def test_resolve_schedule_cron_string(self): + yaml_str = "schedule: '* * * * *'" + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == "* * * * *" + + def test_resolve_schedule_cron_string_alias(self): + yaml_str = "schedule: '@daily'" + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == "@daily" + + def test_resolve_schedule_cron_type_value(self): + yaml_str = """ + schedule: + type: cron + value: "@daily" + """ + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == "@daily" + + def test_resolve_schedule_timetable_type(self): + from airflow.timetables.trigger import CronTriggerTimetable + + yaml_str = """ + schedule: + type: timetable + value: + callable: airflow.timetables.trigger.CronTriggerTimetable + params: + cron: "* * * * *" + timezone: UTC + """ + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == CronTriggerTimetable(cron="* * * * *", timezone="UTC") + + def test_resolve_schedule_timedelta_type(self): + yaml_str = """ + schedule: + type: timedelta + value: + seconds: 30 + """ + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == datetime.timedelta(seconds=30) + + def test_resolve_schedule_relativedelta_type(self): + from dateutil.relativedelta import relativedelta + + yaml_str = """ + schedule: + type: relativedelta + value: + month: 1 + """ + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + assert schedule == relativedelta(month=1) + + def test_resolve_schedule_asset_any_type(self): + from airflow.sdk import Asset, AssetAny + + yaml_str = """ + schedule: + type: assets + value: + or: + - uri: s3://dag1/output_1.txt + extra: + hi: bye + - uri: s3://dag2/output_1.txt + extra: + hi: bye + """ + data = yaml.safe_load(yaml_str) + schedule = DagBuilder._resolve_schedule(data) + + expected = AssetAny( + Asset( + name="s3://dag1/output_1.txt", + uri="s3://dag1/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + Asset( + name="s3://dag2/output_1.txt", + uri="s3://dag2/output_1.txt", + group="asset", + extra={"hi": "bye"}, + watchers=[], + ), + ) + assert schedule == expected + + class TestTopologicalSortTasks: def test_basic_topological_sort(self): From ee1848da14d04f551ea7738f78d9045747f57b31 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Fri, 27 Jun 2025 17:07:02 +0530 Subject: [PATCH 2/3] Fix tests --- dagfactory/dagbuilder.py | 3 +++ tests/test_dagbuilder.py | 19 +++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 4400a4b3..ddfefea5 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -788,6 +788,9 @@ def _resolve_schedule(dag_params): if schedule is None: return None + if schedule.strip().lower() == "none": + return None + # Case 1: If schedule is a string, return it directly if isinstance(schedule, str): return schedule diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 59d74d67..a20cf588 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -11,9 +11,11 @@ from airflow.sdk.definitions import DAG except ImportError: from airflow import DAG +import yaml +from airflow import DAG from packaging import version -from dagfactory.dagbuilder import DagBuilder, DagFactoryConfigException, Dataset +from dagfactory.dagbuilder import INSTALLED_AIRFLOW_VERSION, DagBuilder, DagFactoryConfigException, Dataset from tests.utils import ( get_bash_operator_path, get_http_sensor_path, @@ -23,13 +25,6 @@ one_hour_ago, ) -import yaml -from airflow import DAG -from packaging import version - -from dagfactory.dagbuilder import INSTALLED_AIRFLOW_VERSION, DagBuilder, DagFactoryConfigException, Dataset - - try: from airflow.providers.http.sensors.http import HttpSensor except ImportError: # Airflow < 2.4 @@ -1226,7 +1221,7 @@ def test_asset_schedule_with_and_operator(self): watchers=[], ), ) - assert parsed_schedule == expected + assert parsed_schedule.__eq__(expected) def test_asset_schedule_with_or_operator(self): from airflow.sdk import Asset, AssetAny @@ -1259,7 +1254,7 @@ def test_asset_schedule_with_or_operator(self): watchers=[], ), ) - assert parsed_schedule == expected + assert parsed_schedule.__eq__(expected) def test_asset_schedule_with_nested_operators(self): from airflow.sdk import Asset, AssetAll, AssetAny @@ -1305,7 +1300,7 @@ def test_asset_schedule_with_nested_operators(self): watchers=[], ), ) - assert parsed_schedule == expected + assert parsed_schedule.__eq__(expected) def test_asset_schedule_with_watcher(self): from airflow.providers.standard.triggers.file import FileDeleteTrigger @@ -1438,7 +1433,7 @@ def test_resolve_schedule_asset_any_type(self): watchers=[], ), ) - assert schedule == expected + assert schedule.__eq__(expected) class TestTopologicalSortTasks: From 27fd6ee3020124276a3936987e1bbba6cda783fe Mon Sep 17 00:00:00 2001 From: pankajastro Date: Fri, 27 Jun 2025 17:10:28 +0530 Subject: [PATCH 3/3] Fix tests --- dagfactory/dagbuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index ddfefea5..d1985682 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -788,7 +788,7 @@ def _resolve_schedule(dag_params): if schedule is None: return None - if schedule.strip().lower() == "none": + if isinstance(schedule, str) and schedule.strip().lower() == "none": return None # Case 1: If schedule is a string, return it directly