diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 45a5791d..a38b9868 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -2,6 +2,8 @@ from __future__ import annotations +import ast + # pylint: disable=ungrouped-imports import inspect import os @@ -9,7 +11,7 @@ from copy import deepcopy from datetime import datetime, timedelta from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Tuple, Union from airflow import DAG, configuration from airflow.models import BaseOperator, Variable @@ -83,7 +85,7 @@ from airflow.utils.task_group import TaskGroup from kubernetes.client.models import V1Container, V1Pod -from dagfactory import utils +from dagfactory import parsers, utils from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException # TimeTable is introduced in Airflow 2.2.0 @@ -293,8 +295,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) if not task_params.get("python_callable"): task_params["python_callable"]: Callable = utils.get_python_callable( - task_params["python_callable_name"], - task_params["python_callable_file"], + task_params["python_callable_name"], task_params["python_callable_file"] ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -312,8 +313,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: # Success checks if task_params.get("success_check_file") and task_params.get("success_check_name"): task_params["success"]: Callable = utils.get_python_callable( - task_params["success_check_name"], - task_params["success_check_file"], + task_params["success_check_name"], task_params["success_check_file"] ) del task_params["success_check_name"] del task_params["success_check_file"] @@ -325,8 +325,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: # Failure checks if task_params.get("failure_check_file") and task_params.get("failure_check_name"): task_params["failure"]: Callable = utils.get_python_callable( - task_params["failure_check_name"], - task_params["failure_check_file"], + task_params["failure_check_name"], task_params["failure_check_file"] ) del task_params["failure_check_name"] del task_params["failure_check_file"] @@ -347,8 +346,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) if task_params.get("response_check_file"): task_params["response_check"]: Callable = utils.get_python_callable( - task_params["response_check_name"], - task_params["response_check_file"], + task_params["response_check_name"], task_params["response_check_file"] ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -438,11 +436,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial") ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"): # Getting expand and partial kwargs from task_params - ( - task_params, - expand_kwargs, - partial_kwargs, - ) = utils.get_expand_partial_kwargs(task_params) + (task_params, expand_kwargs, partial_kwargs) = utils.get_expand_partial_kwargs(task_params) # If there are partial_kwargs we should merge them with existing task_params if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params): @@ -626,6 +620,132 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]): task_conf["expand"][expand_key] = tasks_dict[task_id].output return task_conf + @staticmethod + def safe_eval(condition_string: str, dataset_map: dict) -> Any: + """ + Safely evaluates a condition string using the provided dataset map. + + :param condition_string: A string representing the condition to evaluate. + Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3". + :type condition_string: str + :param dataset_map: A dictionary where keys are valid variable names (dataset aliases), + and values are Dataset objects. + :type dataset_map: dict + + :returns: The result of evaluating the condition. + :rtype: Any + """ + tree = ast.parse(condition_string, mode="eval") + evaluator = parsers.SafeEvalVisitor(dataset_map) + return evaluator.evaluate(tree) + + @staticmethod + def _extract_and_transform_datasets(datasets_conditions: str) -> Tuple[str, Dict[str, Any]]: + """ + Extracts dataset names and storage paths from the conditions string and transforms them into valid variable names. + + :param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition. + :type datasets_conditions: str + + :returns: A tuple containing the transformed conditions string and the dataset map. + :rtype: Tuple[str, Dict[str, Any]] + """ + dataset_map = {} + datasets_filter: List[str] = utils.extract_dataset_names(datasets_conditions) + utils.extract_storage_names( + datasets_conditions + ) + + for uri in datasets_filter: + valid_variable_name = utils.make_valid_variable_name(uri) + datasets_conditions = datasets_conditions.replace(uri, valid_variable_name) + dataset_map[valid_variable_name] = Dataset(uri) + + return datasets_conditions, dataset_map + + @staticmethod + def evaluate_condition_with_datasets(datasets_conditions: str) -> Any: + """ + Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + + :param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition. + :type datasets_conditions: str + + :returns: The result of the logical condition evaluation with URIs replaced by valid variable names. + :rtype: Any + """ + datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions) + evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map) + return evaluated_condition + + @staticmethod + def process_file_with_datasets(file: str, datasets_conditions: str) -> Any: + """ + Processes datasets from a file and evaluates conditions if provided. + + :param file: The file path containing dataset information in a YAML or other structured format. + :type file: str + :param datasets_conditions: A string of dataset conditions to filter and process. + :type datasets_conditions: str + + :returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects. + :rtype: Any + """ + is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") + datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions) + + if is_airflow_version_at_least_2_9: + map_datasets = utils.get_datasets_map_uri_yaml_file(file, list(dataset_map.keys())) + dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()} + evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map) + return evaluated_condition + else: + datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys())) + return [Dataset(uri) for uri in datasets_uri] + + @staticmethod + def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None: + """ + Configures the schedule for the DAG based on parameters and the Airflow version. + + :param dag_params: A dictionary containing DAG parameters, including scheduling configuration. + Example: {"schedule": {"file": "datasets.yaml", "datasets": ["dataset_1"], "conditions": "dataset_1 & dataset_2"}} + :type dag_params: Dict[str, Any] + :param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG. + :type dag_kwargs: Dict[str, Any] + + :raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters. + :returns: None. The function updates `dag_kwargs` in-place. + """ + is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") + is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") + has_schedule_attr = utils.check_dict_key(dag_params, "schedule") + has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval") + + if has_schedule_attr and not has_schedule_interval_attr and is_airflow_version_at_least_2_4: + schedule: Dict[str, Any] = dag_params.get("schedule") + + has_file_attr = utils.check_dict_key(schedule, "file") + has_datasets_attr = utils.check_dict_key(schedule, "datasets") + + if has_file_attr and has_datasets_attr: + file = schedule.get("file") + datasets: Union[List[str], str] = schedule.get("datasets") + datasets_conditions: str = utils.parse_list_datasets(datasets) + dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_conditions) + + elif has_datasets_attr and is_airflow_version_at_least_2_9: + datasets = schedule["datasets"] + datasets_conditions: str = utils.parse_list_datasets(datasets) + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_conditions) + + else: + dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] + + if has_file_attr: + schedule.pop("file") + if has_datasets_attr: + schedule.pop("datasets") + # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: """ @@ -649,8 +769,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"): dag_kwargs["max_active_tasks"] = dag_params.get( - "max_active_tasks", - configuration.conf.getint("core", "max_active_tasks_per_dag"), + "max_active_tasks", configuration.conf.getint("core", "max_active_tasks_per_dag") ) if dag_params.get("timetable"): @@ -668,8 +787,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: ) dag_kwargs["max_active_runs"] = dag_params.get( - "max_active_runs", - configuration.conf.getint("core", "max_active_runs_per_dag"), + "max_active_runs", configuration.conf.getint("core", "max_active_runs_per_dag") ) dag_kwargs["dagrun_timeout"] = dag_params.get("dagrun_timeout", None) @@ -702,24 +820,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None) - if ( - utils.check_dict_key(dag_params, "schedule") - and not utils.check_dict_key(dag_params, "schedule_interval") - and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") - ): - if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key( - dag_params["schedule"], "datasets" - ): - file = dag_params["schedule"]["file"] - datasets_filter = dag_params["schedule"]["datasets"] - datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) - - del dag_params["schedule"]["file"] - del dag_params["schedule"]["datasets"] - else: - datasets_uri = dag_params["schedule"] - - dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri] + DagBuilder.configure_schedule(dag_params, dag_kwargs) dag_kwargs["params"] = dag_params.get("params", None) @@ -734,8 +835,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"): doc_md_callable = utils.get_python_callable( - dag_params.get("doc_md_python_callable_name"), - dag_params.get("doc_md_python_callable_file"), + dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file") ) dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {})) @@ -872,8 +972,7 @@ def adjust_general_task_params(task_params: dict(str, Any)): task_params, "execution_date_fn_file" ): task_params["execution_date_fn"]: Callable = utils.get_python_callable( - task_params["execution_date_fn_name"], - task_params["execution_date_fn_file"], + task_params["execution_date_fn_name"], task_params["execution_date_fn_file"] ) del task_params["execution_date_fn_name"] del task_params["execution_date_fn_file"] @@ -937,8 +1036,7 @@ def make_decorator( # Fetch the Python callable if set(mandatory_keys_set1).issubset(task_params): python_callable: Callable = utils.get_python_callable( - task_params["python_callable_name"], - task_params["python_callable_file"], + task_params["python_callable_name"], task_params["python_callable_file"] ) # Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator del task_params["python_callable_name"] diff --git a/dagfactory/dagfactory.py b/dagfactory/dagfactory.py index 6bb485b8..a555f13f 100644 --- a/dagfactory/dagfactory.py +++ b/dagfactory/dagfactory.py @@ -92,15 +92,21 @@ def __join(loader: yaml.FullLoader, node: yaml.Node) -> str: seq = loader.construct_sequence(node) return "".join([str(i) for i in seq]) + def __or(loader: yaml.FullLoader, node: yaml.Node) -> str: + seq = loader.construct_sequence(node) + return " | ".join([f"({str(i)})" for i in seq]) + + def __and(loader: yaml.FullLoader, node: yaml.Node) -> str: + seq = loader.construct_sequence(node) + return " & ".join([f"({str(i)})" for i in seq]) + yaml.add_constructor("!join", __join, yaml.FullLoader) + yaml.add_constructor("!or", __or, yaml.FullLoader) + yaml.add_constructor("!and", __and, yaml.FullLoader) with open(config_filepath, "r", encoding="utf-8") as fp: - yaml.add_constructor("!join", __join, yaml.FullLoader) config_with_env = os.path.expandvars(fp.read()) - config: Dict[str, Any] = yaml.load( - stream=config_with_env, - Loader=yaml.FullLoader, - ) + config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader) except Exception as err: raise DagFactoryConfigException("Invalid DAG Factory config file") from err return config diff --git a/dagfactory/parsers.py b/dagfactory/parsers.py new file mode 100644 index 00000000..4da385ae --- /dev/null +++ b/dagfactory/parsers.py @@ -0,0 +1,34 @@ +import ast + + +class SafeEvalVisitor(ast.NodeVisitor): + def __init__(self, dataset_map): + self.dataset_map = dataset_map + + def evaluate(self, tree): + return self.visit(tree) + + def visit_Expression(self, node): + return self.visit(node.body) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + + if isinstance(node.op, ast.BitAnd): + return left & right + elif isinstance(node.op, ast.BitOr): + return left | right + else: + raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}") + + def visit_Name(self, node): + if node.id in self.dataset_map: + return self.dataset_map[node.id] + raise NameError(f"Undefined variable: {node.id}") + + def visit_Constant(self, node): + return node.value + + def generic_visit(self, node): + raise ValueError(f"Unsupported syntax: {type(node).__name__}") diff --git a/dagfactory/utils.py b/dagfactory/utils.py index c046be19..0d5eb08d 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -202,11 +202,9 @@ def check_template_searchpath(template_searchpath: Union[str, List[str]]) -> boo return False -def get_expand_partial_kwargs(task_params: Dict[str, Any]) -> Tuple[ - Dict[str, Any], - Dict[str, Union[Dict[str, Any], Any]], - Dict[str, Union[Dict[str, Any], Any]], -]: +def get_expand_partial_kwargs( + task_params: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Union[Dict[str, Any], Any]], Dict[str, Union[Dict[str, Any], Any]]]: """ Getting expand and partial kwargs if existed from task_params :param task_params: a dictionary with original task params from yaml @@ -273,3 +271,52 @@ def get_datasets_uri_yaml_file(file_path: str, datasets_filter: str) -> List[str except FileNotFoundError: logging.error("Error: File '%s' not found.", file_path) raise + + +def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict[str, str]: + """ + Retrieves the URIs of datasets from a YAML file based on a given filter. + + :param file_path: The path to the YAML file. + :type file_path: str + :param datasets_filter: A list of dataset names to filter the results. + :type datasets_filter: List[str] + :return: A Dict of dataset URIs that match the filter. + :rtype: Dict[str, str] + """ + try: + with open(file_path, "r", encoding="UTF-8") as file: + data = yaml.safe_load(file) + + datasets = data.get("datasets", []) + datasets_result_dict = { + dataset["name"]: dataset["uri"] + for dataset in datasets + if dataset["name"] in datasets_filter and "uri" in dataset + } + return datasets_result_dict + except FileNotFoundError: + logging.error("Error: File '%s' not found.", file_path) + raise + + +def extract_dataset_names(expression) -> List[str]: + dataset_pattern = r"\b[a-zA-Z_][a-zA-Z0-9_]*\b" + datasets = re.findall(dataset_pattern, expression) + return datasets + + +def extract_storage_names(expression) -> List[str]: + storage_pattern = r"[a-zA-Z][a-zA-Z0-9+.-]*://[a-zA-Z0-9\-_/\.]+" + storages = re.findall(storage_pattern, expression) + return storages + + +def make_valid_variable_name(uri) -> str: + return re.sub(r"\W|^(?=\d)", "_", uri) + + +def parse_list_datasets(datasets: Union[List[str], str]) -> str: + if isinstance(datasets, list): + datasets = " & ".join(datasets) + return datasets diff --git a/dev/dags/datasets/example_dag_datasets.yml b/dev/dags/datasets/example_dag_datasets.yml index e9613ff5..ec14def9 100644 --- a/dev/dags/datasets/example_dag_datasets.yml +++ b/dev/dags/datasets/example_dag_datasets.yml @@ -52,3 +52,27 @@ example_custom_config_dataset_consumer_dag: task_1: operator: airflow.operators.bash_operator.BashOperator bash_command: "echo 'consumer datasets'" + +example_custom_config_condition_dataset_consumer_dag: + description: "Example DAG consumer custom config condition datasets" + schedule: + file: $CONFIG_ROOT_DIR/datasets/example_config_datasets.yml + datasets: "((dataset_custom_1 & dataset_custom_2) | dataset_custom_3)" + tasks: + task_1: + operator: airflow.operators.bash_operator.BashOperator + bash_command: "echo 'consumer datasets'" + +example_without_custom_config_condition_dataset_consumer_dag: + description: "Example DAG consumer custom config condition datasets" + schedule: + datasets: + !or + - !and + - "s3://bucket-cjmm/raw/dataset_custom_1" + - "s3://bucket-cjmm/raw/dataset_custom_2" + - "s3://bucket-cjmm/raw/dataset_custom_3" + tasks: + task_1: + operator: airflow.operators.bash_operator.BashOperator + bash_command: "echo 'consumer datasets'" \ No newline at end of file diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..759ecfee --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,74 @@ +import ast +import pytest +from dagfactory.parsers import SafeEvalVisitor + +@pytest.fixture +def dataset_map(): + return { + 'dataset_custom_1': 1, + 'dataset_custom_2': 2, + 'dataset_custom_3': 3 + } + +@pytest.fixture +def visitor(dataset_map): + return SafeEvalVisitor(dataset_map) + +def test_evaluate(visitor): + condition_string = "dataset_custom_1 & dataset_custom_2 | dataset_custom_3" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = (1 & 2) | 3 + assert result == expected + +def test_visit_BinOp_and(visitor): + condition_string = "dataset_custom_1 & dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 1 & 2 + assert result == expected + +def test_visit_BinOp_or(visitor): + condition_string = "dataset_custom_1 | dataset_custom_3" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 1 | 3 + assert result == expected + +def test_visit_Name(visitor): + condition_string = "dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 2 + assert result == expected + +def test_visit_Constant(visitor): + condition_string = "42" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 42 + assert result == expected + +def test_unsupported_binary_operation(visitor): + condition_string = "dataset_custom_1 + dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) + +def test_unsupported_unary_operation(visitor): + condition_string = "+dataset_custom_1" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) + +def test_undefined_variable(visitor): + condition_string = "undefined_dataset" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(NameError): + visitor.evaluate(tree) + +def test_unsupported_syntax(visitor): + condition_string = "[1, 2, 3]" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 28fe9c33..de5a1c5d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -259,6 +259,32 @@ def test_open_and_filter_yaml_config_datasets(): assert actual == expected +def get_datasets_map_uri_yaml_file(): + datasets_names = ["dataset_custom_1", "dataset_custom_2"] + file_path = "dev/dags/datasets/example_config_datasets.yml" + + actual = utils.get_datasets_uri_yaml_file(file_path, datasets_names) + expected = { + "dataset_custom_1": "s3://bucket-cjmm/raw/dataset_custom_1", + "dataset_custom_2": "s3://bucket-cjmm/raw/dataset_custom_2", + } + + assert actual == expected + +def test_valid_uri(): + actual = utils.make_valid_variable_name("s3://bucket/dataset") + expected = "s3___bucket_dataset" + assert actual == expected + +def test_uri_with_special_characters(): + actual = utils.make_valid_variable_name("s3://bucket/dataset-1!@#$%^&*()") + expected = "s3___bucket_dataset_1__________" + assert actual == expected + +def test_uri_starting_with_number(): + actual = utils.make_valid_variable_name("123/bucket/dataset") + expected = "_123_bucket_dataset" + assert actual == expected def test_open_and_filter_yaml_config_datasets_file_notfound(): datasets_names = ["dataset_custom_1", "dataset_custom_2"] @@ -266,3 +292,35 @@ def test_open_and_filter_yaml_config_datasets_file_notfound(): with pytest.raises(Exception): utils.get_datasets_uri_yaml_file(file_path, datasets_names) + +def test_extract_dataset_names(): + expression = "((dataset_custom_1 & dataset_custom_2) | (dataset_custom_3))" + expected = ["dataset_custom_1", "dataset_custom_2", "dataset_custom_3"] + result = utils.extract_dataset_names(expression) + assert result == expected + + expression = "dataset1 | dataset2 & dataset3" + expected = ["dataset1", "dataset2", "dataset3"] + result = utils.extract_dataset_names(expression) + assert result == expected + + expression = "123_invalid_dataset" + expected = [] + result = utils.extract_dataset_names(expression) + assert result == expected + +def test_extract_storage_names(): + expression = "s3://bucket-cjmm/raw/dataset_custom_1 & s3://bucket-cjmm/raw/dataset_custom_2" + expected = ["s3://bucket-cjmm/raw/dataset_custom_1", "s3://bucket-cjmm/raw/dataset_custom_2"] + result = utils.extract_storage_names(expression) + assert result == expected + + expression = "gs://bucket-name/path/to/data | s3://another-bucket/path" + expected = ["gs://bucket-name/path/to/data", "s3://another-bucket/path"] + result = utils.extract_storage_names(expression) + assert result == expected + + expression = "no_storage_paths_here" + expected = [] + result = utils.extract_storage_names(expression) + assert result == expected