Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4ba99db
Refactor and add support for schedule conditions in DAG configuration:
hiroyukiiseo-hue Dec 16, 2024
4413aad
add utils unit tests
hiroyukiiseo-hue Dec 17, 2024
5eddb01
fix unit test
hiroyukiiseo-hue Dec 17, 2024
ba3ea4b
fix:
hiroyukiiseo-hue Dec 17, 2024
daa5502
feat: add support for processing schedules with conditions and datasets
hiroyukiiseo-hue Dec 19, 2024
74effa6
lint
hiroyukiiseo-hue Dec 19, 2024
84c3172
feat: enhance schedule processing with conditions and datasets
hiroyukiiseo-hue Dec 19, 2024
08d1fa7
fix unit test
hiroyukiiseo-hue Dec 19, 2024
89854a8
fix ruff
hiroyukiiseo-hue Dec 19, 2024
6d26a15
Merge branch 'main' into feat/enable_schedule_dataset_condition
tatiana Jan 2, 2025
995931f
Format __all__ declaration for consistency
hiroyukiiseo-hue Jan 2, 2025
8e2068f
Refactor dataset conditions in example DAG configurations for improve…
hiroyukiiseo-hue Jan 2, 2025
0197264
Add SafeEvalVisitor class for safe AST evaluation of dataset expressions
hiroyukiiseo-hue Jan 2, 2025
db1d089
Add unit tests for SafeEvalVisitor to validate AST evaluation
hiroyukiiseo-hue Jan 2, 2025
7ed2a53
Add functions to extract dataset and storage names from expressions
hiroyukiiseo-hue Jan 2, 2025
e905f5f
Refactor condition evaluation methods in DagBuilder for improved safe…
hiroyukiiseo-hue Jan 2, 2025
6958e24
Lint
hiroyukiiseo-hue Jan 2, 2025
10150bb
Update __all__ in dagfactory and adjust expected output for invalid d…
hiroyukiiseo-hue Jan 2, 2025
cca2b54
Remove unsupported unary operation handling from SafeEvalVisitor
hiroyukiiseo-hue Jan 2, 2025
cf7df81
Fix formatting in __all__ declaration in dagfactory
hiroyukiiseo-hue Jan 2, 2025
4ecb2b6
Refactor dataset handling in DagBuilder to improve clarity and introd…
hiroyukiiseo-hue Jan 3, 2025
574b73a
Merge branch 'main' into feat/enable_schedule_dataset_condition
ErickSeo Jan 3, 2025
5955f18
Merge branch 'main' into feat/enable_schedule_dataset_condition
ErickSeo Jan 8, 2025
ac83b56
Merge branch 'main' into feat/enable_schedule_dataset_condition
tatiana Jan 10, 2025
fb33b67
Merge branch 'main' into feat/enable_schedule_dataset_condition
tatiana Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 141 additions & 43 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

from __future__ import annotations

import ast

# pylint: disable=ungrouped-imports
import inspect
import os
import re
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union

from airflow import DAG, configuration
from airflow.models import BaseOperator, Variable
Expand Down Expand Up @@ -83,7 +85,7 @@
from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod

from dagfactory import utils
from dagfactory import parsers, utils
from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException

# TimeTable is introduced in Airflow 2.2.0
Expand Down Expand Up @@ -293,8 +295,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if not task_params.get("python_callable"):
task_params["python_callable"]: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
Expand All @@ -312,8 +313,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Success checks
if task_params.get("success_check_file") and task_params.get("success_check_name"):
task_params["success"]: Callable = utils.get_python_callable(
task_params["success_check_name"],
task_params["success_check_file"],
task_params["success_check_name"], task_params["success_check_file"]
)
del task_params["success_check_name"]
del task_params["success_check_file"]
Expand All @@ -325,8 +325,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Failure checks
if task_params.get("failure_check_file") and task_params.get("failure_check_name"):
task_params["failure"]: Callable = utils.get_python_callable(
task_params["failure_check_name"],
task_params["failure_check_file"],
task_params["failure_check_name"], task_params["failure_check_file"]
)
del task_params["failure_check_name"]
del task_params["failure_check_file"]
Expand All @@ -347,8 +346,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
)
if task_params.get("response_check_file"):
task_params["response_check"]: Callable = utils.get_python_callable(
task_params["response_check_name"],
task_params["response_check_file"],
task_params["response_check_name"], task_params["response_check_file"]
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
Expand Down Expand Up @@ -438,11 +436,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
# Getting expand and partial kwargs from task_params
(
task_params,
expand_kwargs,
partial_kwargs,
) = utils.get_expand_partial_kwargs(task_params)
(task_params, expand_kwargs, partial_kwargs) = utils.get_expand_partial_kwargs(task_params)

# If there are partial_kwargs we should merge them with existing task_params
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
Expand Down Expand Up @@ -626,6 +620,132 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]):
task_conf["expand"][expand_key] = tasks_dict[task_id].output
return task_conf

@staticmethod
def safe_eval(condition_string: str, dataset_map: dict) -> Any:
"""
Safely evaluates a condition string using the provided dataset map.

:param condition_string: A string representing the condition to evaluate.
Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3".
:type condition_string: str
:param dataset_map: A dictionary where keys are valid variable names (dataset aliases),
and values are Dataset objects.
:type dataset_map: dict

:returns: The result of evaluating the condition.
:rtype: Any
"""
tree = ast.parse(condition_string, mode="eval")
evaluator = parsers.SafeEvalVisitor(dataset_map)
return evaluator.evaluate(tree)

@staticmethod
def _extract_and_transform_datasets(datasets_conditions: str) -> Tuple[str, Dict[str, Any]]:
"""
Extracts dataset names and storage paths from the conditions string and transforms them into valid variable names.

:param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition.
:type datasets_conditions: str

:returns: A tuple containing the transformed conditions string and the dataset map.
:rtype: Tuple[str, Dict[str, Any]]
"""
dataset_map = {}
datasets_filter: List[str] = utils.extract_dataset_names(datasets_conditions) + utils.extract_storage_names(
datasets_conditions
)

for uri in datasets_filter:
valid_variable_name = utils.make_valid_variable_name(uri)
Comment thread
ErickSeo marked this conversation as resolved.
datasets_conditions = datasets_conditions.replace(uri, valid_variable_name)
dataset_map[valid_variable_name] = Dataset(uri)

return datasets_conditions, dataset_map

@staticmethod
def evaluate_condition_with_datasets(datasets_conditions: str) -> Any:
"""
Evaluates a condition using the dataset filter, transforming URIs into valid variable names.

:param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition.
:type datasets_conditions: str

:returns: The result of the logical condition evaluation with URIs replaced by valid variable names.
:rtype: Any
"""
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition

@staticmethod
def process_file_with_datasets(file: str, datasets_conditions: str) -> Any:
"""
Processes datasets from a file and evaluates conditions if provided.

:param file: The file path containing dataset information in a YAML or other structured format.
:type file: str
:param datasets_conditions: A string of dataset conditions to filter and process.
:type datasets_conditions: str

:returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects.
:rtype: Any
"""
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions)

if is_airflow_version_at_least_2_9:
map_datasets = utils.get_datasets_map_uri_yaml_file(file, list(dataset_map.keys()))
dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()}
evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map)
return evaluated_condition
else:
datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys()))
return [Dataset(uri) for uri in datasets_uri]

@staticmethod
def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None:
Comment thread
ErickSeo marked this conversation as resolved.
"""
Configures the schedule for the DAG based on parameters and the Airflow version.

:param dag_params: A dictionary containing DAG parameters, including scheduling configuration.
Example: {"schedule": {"file": "datasets.yaml", "datasets": ["dataset_1"], "conditions": "dataset_1 & dataset_2"}}
:type dag_params: Dict[str, Any]
:param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG.
:type dag_kwargs: Dict[str, Any]

:raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters.
:returns: None. The function updates `dag_kwargs` in-place.
"""
is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0")
has_schedule_attr = utils.check_dict_key(dag_params, "schedule")
has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval")

if has_schedule_attr and not has_schedule_interval_attr and is_airflow_version_at_least_2_4:
schedule: Dict[str, Any] = dag_params.get("schedule")

has_file_attr = utils.check_dict_key(schedule, "file")
has_datasets_attr = utils.check_dict_key(schedule, "datasets")

if has_file_attr and has_datasets_attr:
file = schedule.get("file")
datasets: Union[List[str], str] = schedule.get("datasets")
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_conditions)

elif has_datasets_attr and is_airflow_version_at_least_2_9:
datasets = schedule["datasets"]
datasets_conditions: str = utils.parse_list_datasets(datasets)
dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_conditions)

else:
dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule]

if has_file_attr:
schedule.pop("file")
if has_datasets_attr:
schedule.pop("datasets")

# pylint: disable=too-many-locals
def build(self) -> Dict[str, Union[str, DAG]]:
"""
Expand All @@ -649,8 +769,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"):
dag_kwargs["max_active_tasks"] = dag_params.get(
"max_active_tasks",
configuration.conf.getint("core", "max_active_tasks_per_dag"),
"max_active_tasks", configuration.conf.getint("core", "max_active_tasks_per_dag")
)

if dag_params.get("timetable"):
Expand All @@ -668,8 +787,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:
)

dag_kwargs["max_active_runs"] = dag_params.get(
"max_active_runs",
configuration.conf.getint("core", "max_active_runs_per_dag"),
"max_active_runs", configuration.conf.getint("core", "max_active_runs_per_dag")
)

dag_kwargs["dagrun_timeout"] = dag_params.get("dagrun_timeout", None)
Expand Down Expand Up @@ -702,24 +820,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None)

if (
utils.check_dict_key(dag_params, "schedule")
and not utils.check_dict_key(dag_params, "schedule_interval")
and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
):
if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key(
dag_params["schedule"], "datasets"
):
file = dag_params["schedule"]["file"]
datasets_filter = dag_params["schedule"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del dag_params["schedule"]["file"]
del dag_params["schedule"]["datasets"]
else:
datasets_uri = dag_params["schedule"]

dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri]
DagBuilder.configure_schedule(dag_params, dag_kwargs)

dag_kwargs["params"] = dag_params.get("params", None)

Expand All @@ -734,8 +835,7 @@ def build(self) -> Dict[str, Union[str, DAG]]:

if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"):
doc_md_callable = utils.get_python_callable(
dag_params.get("doc_md_python_callable_name"),
dag_params.get("doc_md_python_callable_file"),
dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file")
)
dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {}))

Expand Down Expand Up @@ -872,8 +972,7 @@ def adjust_general_task_params(task_params: dict(str, Any)):
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
task_params["execution_date_fn_name"], task_params["execution_date_fn_file"]
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]
Expand Down Expand Up @@ -937,8 +1036,7 @@ def make_decorator(
# Fetch the Python callable
if set(mandatory_keys_set1).issubset(task_params):
python_callable: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable_name"], task_params["python_callable_file"]
)
# Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator
del task_params["python_callable_name"]
Expand Down
16 changes: 11 additions & 5 deletions dagfactory/dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,21 @@ def __join(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return "".join([str(i) for i in seq])

def __or(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " | ".join([f"({str(i)})" for i in seq])

def __and(loader: yaml.FullLoader, node: yaml.Node) -> str:
seq = loader.construct_sequence(node)
return " & ".join([f"({str(i)})" for i in seq])

yaml.add_constructor("!join", __join, yaml.FullLoader)
yaml.add_constructor("!or", __or, yaml.FullLoader)
yaml.add_constructor("!and", __and, yaml.FullLoader)

with open(config_filepath, "r", encoding="utf-8") as fp:
yaml.add_constructor("!join", __join, yaml.FullLoader)
config_with_env = os.path.expandvars(fp.read())
config: Dict[str, Any] = yaml.load(
stream=config_with_env,
Loader=yaml.FullLoader,
)
config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader)
except Exception as err:
raise DagFactoryConfigException("Invalid DAG Factory config file") from err
return config
Expand Down
34 changes: 34 additions & 0 deletions dagfactory/parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ast


class SafeEvalVisitor(ast.NodeVisitor):
def __init__(self, dataset_map):
self.dataset_map = dataset_map

def evaluate(self, tree):
return self.visit(tree)

def visit_Expression(self, node):
return self.visit(node.body)

def visit_BinOp(self, node):
left = self.visit(node.left)
right = self.visit(node.right)

if isinstance(node.op, ast.BitAnd):
return left & right
elif isinstance(node.op, ast.BitOr):
return left | right
else:
raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}")

def visit_Name(self, node):
if node.id in self.dataset_map:
return self.dataset_map[node.id]
raise NameError(f"Undefined variable: {node.id}")

def visit_Constant(self, node):
return node.value

def generic_visit(self, node):
raise ValueError(f"Unsupported syntax: {type(node).__name__}")
Loading