From 4ba99db9ef6da72469028dca2c375cd0a0072f57 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Mon, 16 Dec 2024 18:04:21 -0300 Subject: [PATCH 01/20] Refactor and add support for schedule conditions in DAG configuration: - Added support for schedules defined by conditions, enabling dynamic scheduling based on dataset filters and conditions. - Introduced `configure_schedule` function to streamline DAG schedule setup based on Airflow version and parameters. - Created `process_file_with_datasets` function to handle dataset processing and conditional evaluation from files. - Implemented `evaluate_condition_with_datasets` to evaluate schedule conditions while ensuring valid variable names for dataset URIs. - Replaced repetitive code with reusable functions for better modularity and maintainability. - Enhanced code readability by adding detailed docstrings for all functions, following a standard format. - Improved safety by avoiding reliance on `globals()` in `evaluate_condition_with_datasets`. --- dagfactory/dagbuilder.py | 132 ++++++++++++++++++--- dagfactory/utils.py | 29 +++++ dev/dags/datasets/example_dag_datasets.yml | 21 ++++ 3 files changed, 163 insertions(+), 19 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 15a6fe17..99e15bcf 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -9,7 +9,7 @@ from copy import deepcopy from datetime import datetime, timedelta from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Union, Optional from airflow import DAG, configuration from airflow.models import BaseOperator, Variable @@ -623,6 +623,117 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]): if task_id in tasks_dict: task_conf["expand"][expand_key] = tasks_dict[task_id].output return task_conf + + @staticmethod + def evaluate_condition_with_datasets( + condition_string: str, + datasets_filter: List[str] + ) -> Any: + """ + Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + + :param condition_string: A string representing the logical condition to evaluate. + Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3". + :type condition_string: str + :param datasets_filter: A list of dataset URIs to be evaluated in the condition. + :type datasets_filter: List[str] + + :returns: The result of the logical condition evaluation with URIs replaced by valid variable names. + :rtype: Any + """ + dataset_map = {} + for uri in datasets_filter: + valid_variable_name = utils.make_valid_variable_name(uri) + condition_string = condition_string.replace(uri, valid_variable_name) + dataset_map[valid_variable_name] = Dataset(uri) + evaluated_condition = eval(condition_string, {}, dataset_map) + return evaluated_condition + + @staticmethod + def process_file_with_datasets( + file: str, datasets_filter: List[str], condition_string: Optional[str] = None + ) -> Any: + """ + Processes datasets from a file and evaluates conditions if provided. + + :param file: The file path containing dataset information in a YAML or other structured format. + :type file: str + :param datasets_filter: A list of dataset names to filter and process. + :type datasets_filter: List[str] + :param condition_string: A logical condition string to evaluate using the datasets. + If not provided, the function returns a list of `Dataset` objects based on the file and filter. + Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3". + :type condition_string: Optional[str] + + :returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects. + :rtype: Any + """ + if condition_string: + map_datasets = utils.get_datasets_map_uri_yaml_file(file, datasets_filter) + dataset_map = { + alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items() + } + return eval(condition_string, {}, dataset_map) + else: + datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) + return [Dataset(uri) for uri in datasets_uri] + + @staticmethod + def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None: + """ + Configures the schedule for the DAG based on parameters and the Airflow version. + + :param dag_params: A dictionary containing DAG parameters, including scheduling configuration. + Example: {"schedule": {"file": "datasets.yaml", "datasets": ["dataset_1"], "conditions": "dataset_1 & dataset_2"}} + :type dag_params: Dict[str, Any] + :param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG. + :type dag_kwargs: Dict[str, Any] + + :raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters. + :returns: None. The function updates `dag_kwargs` in-place. + """ + is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") + has_schedule_attr = utils.check_dict_key(dag_params, "schedule") + has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval") + + if has_schedule_attr and not has_schedule_interval_attr and is_airflow_version_at_least_2_4: + schedule: Dict[str, Any] = dag_params.get("schedule") + + has_file_attr = utils.check_dict_key(schedule, "file") + has_datasets_attr = utils.check_dict_key(schedule, "datasets") + has_conditions_attr = utils.check_dict_key(schedule, "conditions") + + if has_file_attr and has_datasets_attr: + file = schedule.get("file") + datasets_filter = schedule.get("datasets") + condition_string = schedule.get("conditions") if has_conditions_attr else None + + dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets( + file, datasets_filter, condition_string + ) + + # Remove processed keys from the schedule + schedule.pop("file", None) + schedule.pop("datasets", None) + if has_conditions_attr: + schedule.pop("conditions", None) + + # Process condition-based schedules directly from datasets + elif has_conditions_attr and has_datasets_attr: + datasets_filter = schedule["datasets"] + condition_string = schedule["conditions"] + + # Evaluate the condition and set the schedule + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets( + condition_string, datasets_filter + ) + + # Remove the processed condition key + schedule.pop("conditions", None) + + # Handle basic schedules with direct dataset URIs + else: + dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: @@ -698,24 +809,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None) - if ( - utils.check_dict_key(dag_params, "schedule") - and not utils.check_dict_key(dag_params, "schedule_interval") - and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") - ): - if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key( - dag_params["schedule"], "datasets" - ): - file = dag_params["schedule"]["file"] - datasets_filter = dag_params["schedule"]["datasets"] - datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) - - del dag_params["schedule"]["file"] - del dag_params["schedule"]["datasets"] - else: - datasets_uri = dag_params["schedule"] - - dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri] + DagBuilder.configure_schedule(dag_params, dag_kwargs) dag_kwargs["params"] = dag_params.get("params", None) diff --git a/dagfactory/utils.py b/dagfactory/utils.py index c046be19..8d66ce87 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -273,3 +273,32 @@ def get_datasets_uri_yaml_file(file_path: str, datasets_filter: str) -> List[str except FileNotFoundError: logging.error("Error: File '%s' not found.", file_path) raise + +def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict[str, str]: + """ + Retrieves the URIs of datasets from a YAML file based on a given filter. + + :param file_path: The path to the YAML file. + :type file_path: str + :param datasets_filter: A list of dataset names to filter the results. + :type datasets_filter: List[str] + :return: A Dict of dataset URIs that match the filter. + :rtype: Dict[str, str] + """ + try: + with open(file_path, "r", encoding="UTF-8") as file: + data = yaml.safe_load(file) + + datasets = data.get("datasets", []) + datasets_result_dict = { + dataset["name"]: dataset["uri"] + for dataset in datasets + if dataset["name"] in datasets_filter and "uri" in dataset + } + return datasets_result_dict + except FileNotFoundError: + logging.error("Error: File '%s' not found.", file_path) + raise + +def make_valid_variable_name(uri): + return re.sub(r'\W|^(?=\d)', '_', uri) \ No newline at end of file diff --git a/dev/dags/datasets/example_dag_datasets.yml b/dev/dags/datasets/example_dag_datasets.yml index e9613ff5..cdf8be04 100644 --- a/dev/dags/datasets/example_dag_datasets.yml +++ b/dev/dags/datasets/example_dag_datasets.yml @@ -52,3 +52,24 @@ example_custom_config_dataset_consumer_dag: task_1: operator: airflow.operators.bash_operator.BashOperator bash_command: "echo 'consumer datasets'" + +example_custom_config_condition_dataset_consumer_dag: + description: "Example DAG consumer custom config condition datasets" + schedule: + file: $CONFIG_ROOT_DIR/datasets/example_config_datasets.yml + datasets: ['dataset_custom_1', 'dataset_custom_2', 'dataset_custom_3'] + conditions: "((dataset_custom_1 & dataset_custom_2) | dataset_custom_3)" + tasks: + task_1: + operator: airflow.operators.bash_operator.BashOperator + bash_command: "echo 'consumer datasets'" + +example_without_custom_config_condition_dataset_consumer_dag: + description: "Example DAG consumer custom config condition datasets" + schedule: + datasets: ['s3://bucket-cjmm/raw/dataset_custom_1', 's3://bucket-cjmm/raw/dataset_custom_2', 's3://bucket-cjmm/raw/dataset_custom_3'] + conditions: "((s3://bucket-cjmm/raw/dataset_custom_1 & s3://bucket-cjmm/raw/dataset_custom_2) | s3://bucket-cjmm/raw/dataset_custom_3)" + tasks: + task_1: + operator: airflow.operators.bash_operator.BashOperator + bash_command: "echo 'consumer datasets'" \ No newline at end of file From 4413aada07ca7fd91714f30316df2d4ad8a7e49a Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Tue, 17 Dec 2024 09:32:35 -0300 Subject: [PATCH 02/20] add utils unit tests --- tests/test_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 28fe9c33..5fc56c7e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -259,6 +259,32 @@ def test_open_and_filter_yaml_config_datasets(): assert actual == expected +def get_datasets_map_uri_yaml_file(): + datasets_names = ["dataset_custom_1", "dataset_custom_2"] + file_path = "dev/dags/datasets/example_config_datasets.yml" + + actual = utils.get_datasets_uri_yaml_file(file_path, datasets_names) + expected = { + "dataset_custom_1": "s3://bucket-cjmm/raw/dataset_custom_1", + "dataset_custom_2": "s3://bucket-cjmm/raw/dataset_custom_2", + } + + assert actual == expected + +def test_valid_uri(): + actual = utils.make_valid_variable_name("s3://bucket/dataset") + expected = "s3__bucket_dataset" + assert actual == expected + +def test_uri_with_special_characters(self): + actual = utils.make_valid_variable_name("s3://bucket/dataset-1!@#$%^&*()") + expected = "s3__bucket_dataset_1_____________" + assert actual == expected + +def test_uri_starting_with_number(self): + actual = utils.make_valid_variable_name("123/bucket/dataset") + expected = "_123_bucket_dataset" + assert actual == expected def test_open_and_filter_yaml_config_datasets_file_notfound(): datasets_names = ["dataset_custom_1", "dataset_custom_2"] From 5eddb018a11ce066546769e305a3c5fe004cfa7d Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Tue, 17 Dec 2024 12:04:54 -0300 Subject: [PATCH 03/20] fix unit test --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5fc56c7e..550e3d8e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -273,12 +273,12 @@ def get_datasets_map_uri_yaml_file(): def test_valid_uri(): actual = utils.make_valid_variable_name("s3://bucket/dataset") - expected = "s3__bucket_dataset" + expected = "s3___bucket_dataset" assert actual == expected def test_uri_with_special_characters(self): actual = utils.make_valid_variable_name("s3://bucket/dataset-1!@#$%^&*()") - expected = "s3__bucket_dataset_1_____________" + expected = "s3___bucket_dataset_1__________" assert actual == expected def test_uri_starting_with_number(self): From ba3ea4b0c48763059e4007176ab0d3474600e9cc Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Tue, 17 Dec 2024 13:03:40 -0300 Subject: [PATCH 04/20] fix: - remove self from unit test --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 550e3d8e..4234a5c9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -276,12 +276,12 @@ def test_valid_uri(): expected = "s3___bucket_dataset" assert actual == expected -def test_uri_with_special_characters(self): +def test_uri_with_special_characters(): actual = utils.make_valid_variable_name("s3://bucket/dataset-1!@#$%^&*()") expected = "s3___bucket_dataset_1__________" assert actual == expected -def test_uri_starting_with_number(self): +def test_uri_starting_with_number(): actual = utils.make_valid_variable_name("123/bucket/dataset") expected = "_123_bucket_dataset" assert actual == expected From daa55022bcece2abe99d8eb91b636dcaa669c5a8 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Wed, 18 Dec 2024 21:45:29 -0300 Subject: [PATCH 05/20] feat: add support for processing schedules with conditions and datasets - Implemented logic to handle schedules with both file and datasets attributes. - Added support for evaluating conditions with datasets for Airflow version 2.9 and above. - Cleaned up schedule dictionary by removing processed keys. --- dagfactory/dagbuilder.py | 45 ++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 99e15bcf..2bb1c395 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -623,12 +623,9 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]): if task_id in tasks_dict: task_conf["expand"][expand_key] = tasks_dict[task_id].output return task_conf - + @staticmethod - def evaluate_condition_with_datasets( - condition_string: str, - datasets_filter: List[str] - ) -> Any: + def evaluate_condition_with_datasets(condition_string: str, datasets_filter: List[str]) -> Any: """ Evaluates a condition using the dataset filter, transforming URIs into valid variable names. @@ -670,14 +667,12 @@ def process_file_with_datasets( """ if condition_string: map_datasets = utils.get_datasets_map_uri_yaml_file(file, datasets_filter) - dataset_map = { - alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items() - } + dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()} return eval(condition_string, {}, dataset_map) else: datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) return [Dataset(uri) for uri in datasets_uri] - + @staticmethod def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) -> None: """ @@ -688,11 +683,12 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - :type dag_params: Dict[str, Any] :param dag_kwargs: A dictionary for setting the resulting schedule configuration for the DAG. :type dag_kwargs: Dict[str, Any] - + :raises KeyError: If required keys like "schedule" or "datasets" are missing in the parameters. :returns: None. The function updates `dag_kwargs` in-place. """ is_airflow_version_at_least_2_4 = version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") + is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") has_schedule_attr = utils.check_dict_key(dag_params, "schedule") has_schedule_interval_attr = utils.check_dict_key(dag_params, "schedule_interval") @@ -706,35 +702,26 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - if has_file_attr and has_datasets_attr: file = schedule.get("file") datasets_filter = schedule.get("datasets") - condition_string = schedule.get("conditions") if has_conditions_attr else None - - dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets( - file, datasets_filter, condition_string - ) + condition_string = schedule.get("conditions") - # Remove processed keys from the schedule - schedule.pop("file", None) - schedule.pop("datasets", None) - if has_conditions_attr: - schedule.pop("conditions", None) + dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_filter, condition_string) - # Process condition-based schedules directly from datasets elif has_conditions_attr and has_datasets_attr: datasets_filter = schedule["datasets"] condition_string = schedule["conditions"] - # Evaluate the condition and set the schedule - dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets( - condition_string, datasets_filter - ) - - # Remove the processed condition key - schedule.pop("conditions", None) + if is_airflow_version_at_least_2_9: + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets( + condition_string, datasets_filter + ) - # Handle basic schedules with direct dataset URIs else: dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] + schedule.pop("file", None) + schedule.pop("datasets", None) + schedule.pop("conditions", None) + # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: """ From 74effa624f80cd23eff1b632f541d8c34b3eec7c Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Wed, 18 Dec 2024 21:45:42 -0300 Subject: [PATCH 06/20] lint --- dagfactory/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dagfactory/utils.py b/dagfactory/utils.py index 8d66ce87..c25c295a 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -274,6 +274,7 @@ def get_datasets_uri_yaml_file(file_path: str, datasets_filter: str) -> List[str logging.error("Error: File '%s' not found.", file_path) raise + def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict[str, str]: """ Retrieves the URIs of datasets from a YAML file based on a given filter. @@ -291,8 +292,8 @@ def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict datasets = data.get("datasets", []) datasets_result_dict = { - dataset["name"]: dataset["uri"] - for dataset in datasets + dataset["name"]: dataset["uri"] + for dataset in datasets if dataset["name"] in datasets_filter and "uri" in dataset } return datasets_result_dict @@ -300,5 +301,6 @@ def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict logging.error("Error: File '%s' not found.", file_path) raise + def make_valid_variable_name(uri): - return re.sub(r'\W|^(?=\d)', '_', uri) \ No newline at end of file + return re.sub(r"\W|^(?=\d)", "_", uri) From 84c31728c0a2685b6826c5016d9c09e942b0b411 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Wed, 18 Dec 2024 21:52:58 -0300 Subject: [PATCH 07/20] feat: enhance schedule processing with conditions and datasets - Added logic to handle schedules with both file and datasets attributes. - Implemented support for evaluating conditions with datasets for Airflow version 2.9 and above. - Cleaned up schedule dictionary by removing processed keys after use. --- dagfactory/__init__.py | 5 +--- dagfactory/dagbuilder.py | 52 ++++++++++++++-------------------------- dagfactory/dagfactory.py | 9 ++----- dagfactory/utils.py | 8 +++---- 4 files changed, 24 insertions(+), 50 deletions(-) diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index b86179da..fece8933 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -3,7 +3,4 @@ from .dagfactory import DagFactory, load_yaml_dags __version__ = "0.21.0" -__all__ = [ - "DagFactory", - "load_yaml_dags", -] +__all__ = ["DagFactory", "load_yaml_dags"] diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 2bb1c395..eabb177d 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -213,8 +213,7 @@ def get_dag_params(self) -> Dict[str, Any]: dag_params, "on_success_callback_file" ): dag_params["on_success_callback"]: Callable = utils.get_python_callable( - dag_params["on_success_callback_name"], - dag_params["on_success_callback_file"], + dag_params["on_success_callback_name"], dag_params["on_success_callback_file"] ) if utils.check_dict_key(dag_params, "on_failure_callback_name") and utils.check_dict_key( @@ -317,8 +316,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) if not task_params.get("python_callable"): task_params["python_callable"]: Callable = utils.get_python_callable( - task_params["python_callable_name"], - task_params["python_callable_file"], + task_params["python_callable_name"], task_params["python_callable_file"] ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -336,8 +334,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: # Success checks if task_params.get("success_check_file") and task_params.get("success_check_name"): task_params["success"]: Callable = utils.get_python_callable( - task_params["success_check_name"], - task_params["success_check_file"], + task_params["success_check_name"], task_params["success_check_file"] ) del task_params["success_check_name"] del task_params["success_check_file"] @@ -349,8 +346,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: # Failure checks if task_params.get("failure_check_file") and task_params.get("failure_check_name"): task_params["failure"]: Callable = utils.get_python_callable( - task_params["failure_check_name"], - task_params["failure_check_file"], + task_params["failure_check_name"], task_params["failure_check_file"] ) del task_params["failure_check_name"] del task_params["failure_check_file"] @@ -371,8 +367,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) if task_params.get("response_check_file"): task_params["response_check"]: Callable = utils.get_python_callable( - task_params["response_check_name"], - task_params["response_check_file"], + task_params["response_check_name"], task_params["response_check_file"] ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -461,11 +456,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial") ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"): # Getting expand and partial kwargs from task_params - ( - task_params, - expand_kwargs, - partial_kwargs, - ) = utils.get_expand_partial_kwargs(task_params) + (task_params, expand_kwargs, partial_kwargs) = utils.get_expand_partial_kwargs(task_params) # If there are partial_kwargs we should merge them with existing task_params if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params): @@ -505,12 +496,7 @@ def _init_task_group_callback_param(task_group_conf): return task_group_conf default_args = task_group_conf["default_args"] - callback_keys = [ - "on_success_callback", - "on_execute_callback", - "on_failure_callback", - "on_retry_callback", - ] + callback_keys = ["on_success_callback", "on_execute_callback", "on_failure_callback", "on_retry_callback"] for key in callback_keys: if key in default_args and isinstance(default_args[key], str): @@ -718,9 +704,12 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - else: dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] - schedule.pop("file", None) - schedule.pop("datasets", None) - schedule.pop("conditions", None) + if has_file_attr: + schedule.pop("file") + if has_datasets_attr: + schedule.pop("datasets") + if has_conditions_attr: + schedule.pop("conditions") # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: @@ -743,8 +732,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"): dag_kwargs["max_active_tasks"] = dag_params.get( - "max_active_tasks", - configuration.conf.getint("core", "max_active_tasks_per_dag"), + "max_active_tasks", configuration.conf.getint("core", "max_active_tasks_per_dag") ) if dag_params.get("timetable"): @@ -762,8 +750,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: ) dag_kwargs["max_active_runs"] = dag_params.get( - "max_active_runs", - configuration.conf.getint("core", "max_active_runs_per_dag"), + "max_active_runs", configuration.conf.getint("core", "max_active_runs_per_dag") ) dag_kwargs["dagrun_timeout"] = dag_params.get("dagrun_timeout", None) @@ -811,8 +798,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"): doc_md_callable = utils.get_python_callable( - dag_params.get("doc_md_python_callable_name"), - dag_params.get("doc_md_python_callable_file"), + dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file") ) dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {})) @@ -941,8 +927,7 @@ def adjust_general_task_params(task_params: dict(str, Any)): task_params, "execution_date_fn_file" ): task_params["execution_date_fn"]: Callable = utils.get_python_callable( - task_params["execution_date_fn_name"], - task_params["execution_date_fn_file"], + task_params["execution_date_fn_name"], task_params["execution_date_fn_file"] ) del task_params["execution_date_fn_name"] del task_params["execution_date_fn_file"] @@ -998,8 +983,7 @@ def make_decorator( # Fetch the Python callable if set(mandatory_keys_set1).issubset(task_params): python_callable: Callable = utils.get_python_callable( - task_params["python_callable_name"], - task_params["python_callable_file"], + task_params["python_callable_name"], task_params["python_callable_file"] ) # Remove dag-factory specific parameters since Airflow 2.0 doesn't allow these to be passed to operator del task_params["python_callable_name"] diff --git a/dagfactory/dagfactory.py b/dagfactory/dagfactory.py index 3495f4af..a0f2fc8e 100644 --- a/dagfactory/dagfactory.py +++ b/dagfactory/dagfactory.py @@ -84,10 +84,7 @@ def __join(loader: yaml.FullLoader, node: yaml.Node) -> str: with open(config_filepath, "r", encoding="utf-8") as fp: yaml.add_constructor("!join", __join, yaml.FullLoader) config_with_env = os.path.expandvars(fp.read()) - config: Dict[str, Any] = yaml.load( - stream=config_with_env, - Loader=yaml.FullLoader, - ) + config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader) except Exception as err: raise DagFactoryConfigException("Invalid DAG Factory config file") from err return config @@ -177,9 +174,7 @@ def clean_dags(self, globals: Dict[str, Any]) -> None: def load_yaml_dags( - globals_dict: Dict[str, Any], - dags_folder: str = airflow_conf.get("core", "dags_folder"), - suffix=None, + globals_dict: Dict[str, Any], dags_folder: str = airflow_conf.get("core", "dags_folder"), suffix=None ): """ Loads all the yaml/yml files in the dags folder diff --git a/dagfactory/utils.py b/dagfactory/utils.py index c25c295a..1a358873 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -202,11 +202,9 @@ def check_template_searchpath(template_searchpath: Union[str, List[str]]) -> boo return False -def get_expand_partial_kwargs(task_params: Dict[str, Any]) -> Tuple[ - Dict[str, Any], - Dict[str, Union[Dict[str, Any], Any]], - Dict[str, Union[Dict[str, Any], Any]], -]: +def get_expand_partial_kwargs( + task_params: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Union[Dict[str, Any], Any]], Dict[str, Union[Dict[str, Any], Any]]]: """ Getting expand and partial kwargs if existed from task_params :param task_params: a dictionary with original task params from yaml From 08d1fa78d26ff346ad2ecdd761c375cae83f15ed Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Wed, 18 Dec 2024 22:00:02 -0300 Subject: [PATCH 08/20] fix unit test --- dagfactory/dagbuilder.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index eabb177d..c5b1f573 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -651,7 +651,8 @@ def process_file_with_datasets( :returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects. :rtype: Any """ - if condition_string: + is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") + if condition_string and is_airflow_version_at_least_2_9: map_datasets = utils.get_datasets_map_uri_yaml_file(file, datasets_filter) dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()} return eval(condition_string, {}, dataset_map) @@ -692,14 +693,10 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_filter, condition_string) - elif has_conditions_attr and has_datasets_attr: + elif has_conditions_attr and has_datasets_attr and is_airflow_version_at_least_2_9: datasets_filter = schedule["datasets"] condition_string = schedule["conditions"] - - if is_airflow_version_at_least_2_9: - dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets( - condition_string, datasets_filter - ) + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(condition_string, datasets_filter) else: dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] From 89854a876122df6bc005f8866fa25509b4e389c1 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Wed, 18 Dec 2024 22:06:57 -0300 Subject: [PATCH 09/20] fix ruff --- dagfactory/dagbuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index c5b1f573..311f41e6 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -9,7 +9,7 @@ from copy import deepcopy from datetime import datetime, timedelta from functools import partial -from typing import Any, Callable, Dict, List, Union, Optional +from typing import Any, Callable, Dict, List, Optional, Union from airflow import DAG, configuration from airflow.models import BaseOperator, Variable From 995931fd0034aa54dcbb7a0083ce46400134cee0 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:56:13 -0300 Subject: [PATCH 10/20] Format __all__ declaration for consistency --- dagfactory/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index fece8933..1d5e3a5b 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -3,4 +3,7 @@ from .dagfactory import DagFactory, load_yaml_dags __version__ = "0.21.0" -__all__ = ["DagFactory", "load_yaml_dags"] +__all__ = [ + "DagFactory", + "load_yaml_dags" +] From 8e2068fb4f688a2c124864c42be97f20b77c8fd1 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:56:26 -0300 Subject: [PATCH 11/20] Refactor dataset conditions in example DAG configurations for improved clarity --- dev/dags/datasets/example_dag_datasets.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dev/dags/datasets/example_dag_datasets.yml b/dev/dags/datasets/example_dag_datasets.yml index cdf8be04..ec14def9 100644 --- a/dev/dags/datasets/example_dag_datasets.yml +++ b/dev/dags/datasets/example_dag_datasets.yml @@ -57,8 +57,7 @@ example_custom_config_condition_dataset_consumer_dag: description: "Example DAG consumer custom config condition datasets" schedule: file: $CONFIG_ROOT_DIR/datasets/example_config_datasets.yml - datasets: ['dataset_custom_1', 'dataset_custom_2', 'dataset_custom_3'] - conditions: "((dataset_custom_1 & dataset_custom_2) | dataset_custom_3)" + datasets: "((dataset_custom_1 & dataset_custom_2) | dataset_custom_3)" tasks: task_1: operator: airflow.operators.bash_operator.BashOperator @@ -67,8 +66,12 @@ example_custom_config_condition_dataset_consumer_dag: example_without_custom_config_condition_dataset_consumer_dag: description: "Example DAG consumer custom config condition datasets" schedule: - datasets: ['s3://bucket-cjmm/raw/dataset_custom_1', 's3://bucket-cjmm/raw/dataset_custom_2', 's3://bucket-cjmm/raw/dataset_custom_3'] - conditions: "((s3://bucket-cjmm/raw/dataset_custom_1 & s3://bucket-cjmm/raw/dataset_custom_2) | s3://bucket-cjmm/raw/dataset_custom_3)" + datasets: + !or + - !and + - "s3://bucket-cjmm/raw/dataset_custom_1" + - "s3://bucket-cjmm/raw/dataset_custom_2" + - "s3://bucket-cjmm/raw/dataset_custom_3" tasks: task_1: operator: airflow.operators.bash_operator.BashOperator From 0197264c772e9efb8be9fae420621a4e6269eb6f Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:56:36 -0300 Subject: [PATCH 12/20] Add SafeEvalVisitor class for safe AST evaluation of dataset expressions --- dagfactory/parsers.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 dagfactory/parsers.py diff --git a/dagfactory/parsers.py b/dagfactory/parsers.py new file mode 100644 index 00000000..642c57c3 --- /dev/null +++ b/dagfactory/parsers.py @@ -0,0 +1,41 @@ +import ast + + +class SafeEvalVisitor(ast.NodeVisitor): + def __init__(self, dataset_map): + self.dataset_map = dataset_map + + def evaluate(self, tree): + return self.visit(tree) + + def visit_Expression(self, node): + return self.visit(node.body) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + + if isinstance(node.op, ast.BitAnd): + return left & right + elif isinstance(node.op, ast.BitOr): + return left | right + else: + raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}") + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + if isinstance(node.op, ast.Not): + return ~operand + else: + raise ValueError(f"Unsupported unary operation: {type(node.op).__name__}") + + def visit_Name(self, node): + if node.id in self.dataset_map: + return self.dataset_map[node.id] + raise NameError(f"Undefined variable: {node.id}") + + def visit_Constant(self, node): + return node.value + + def generic_visit(self, node): + raise ValueError(f"Unsupported syntax: {type(node).__name__}") \ No newline at end of file From db1d089a2b2dc459a78a0eab54191bcb46e5d710 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:56:42 -0300 Subject: [PATCH 13/20] Add unit tests for SafeEvalVisitor to validate AST evaluation --- tests/test_parsers.py | 81 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tests/test_parsers.py diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..fe6af3e7 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,81 @@ +import ast +import pytest +from dagfactory.parsers import SafeEvalVisitor + +@pytest.fixture +def dataset_map(): + return { + 'dataset_custom_1': 1, + 'dataset_custom_2': 2, + 'dataset_custom_3': 3 + } + +@pytest.fixture +def visitor(dataset_map): + return SafeEvalVisitor(dataset_map) + +def test_evaluate(visitor): + condition_string = "dataset_custom_1 & dataset_custom_2 | dataset_custom_3" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = (1 & 2) | 3 + assert result == expected + +def test_visit_BinOp_and(visitor): + condition_string = "dataset_custom_1 & dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 1 & 2 + assert result == expected + +def test_visit_BinOp_or(visitor): + condition_string = "dataset_custom_1 | dataset_custom_3" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 1 | 3 + assert result == expected + +def test_visit_UnaryOp_not(visitor): + condition_string = "~dataset_custom_1" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = ~1 + assert result == expected + +def test_visit_Name(visitor): + condition_string = "dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 2 + assert result == expected + +def test_visit_Constant(visitor): + condition_string = "42" + tree = ast.parse(condition_string, mode='eval') + result = visitor.evaluate(tree) + expected = 42 + assert result == expected + +def test_unsupported_binary_operation(visitor): + condition_string = "dataset_custom_1 + dataset_custom_2" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) + +def test_unsupported_unary_operation(visitor): + condition_string = "+dataset_custom_1" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) + +def test_undefined_variable(visitor): + condition_string = "undefined_dataset" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(NameError): + visitor.evaluate(tree) + +def test_unsupported_syntax(visitor): + condition_string = "[1, 2, 3]" + tree = ast.parse(condition_string, mode='eval') + with pytest.raises(ValueError): + visitor.evaluate(tree) \ No newline at end of file From 7ed2a53a83ba14c5eb83eddfe5a5b1acbdfb40b4 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:56:53 -0300 Subject: [PATCH 14/20] Add functions to extract dataset and storage names from expressions --- dagfactory/utils.py | 11 ++++++++++- tests/test_utils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/dagfactory/utils.py b/dagfactory/utils.py index 1a358873..c9e98ed2 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -299,6 +299,15 @@ def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict logging.error("Error: File '%s' not found.", file_path) raise +def extract_dataset_names(expression) -> List[str]: + dataset_pattern = r'\b[a-zA-Z_][a-zA-Z0-9_]*\b' + datasets = re.findall(dataset_pattern, expression) + return datasets -def make_valid_variable_name(uri): +def extract_storage_names(expression) -> List[str]: + storage_pattern = r'[a-zA-Z][a-zA-Z0-9+.-]*://[a-zA-Z0-9\-_/\.]+' + storages = re.findall(storage_pattern, expression) + return storages + +def make_valid_variable_name(uri) -> str: return re.sub(r"\W|^(?=\d)", "_", uri) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4234a5c9..b7b50ce1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -292,3 +292,35 @@ def test_open_and_filter_yaml_config_datasets_file_notfound(): with pytest.raises(Exception): utils.get_datasets_uri_yaml_file(file_path, datasets_names) + +def test_extract_dataset_names(): + expression = "((dataset_custom_1 & dataset_custom_2) | (dataset_custom_3))" + expected = ["dataset_custom_1", "dataset_custom_2", "dataset_custom_3"] + result = utils.extract_dataset_names(expression) + assert result == expected + + expression = "dataset1 | dataset2 & dataset3" + expected = ["dataset1", "dataset2", "dataset3"] + result = utils.extract_dataset_names(expression) + assert result == expected + + expression = "123_invalid_dataset" + expected = ["invalid_dataset"] + result = utils.extract_dataset_names(expression) + assert result == expected + +def test_extract_storage_names(): + expression = "s3://bucket-cjmm/raw/dataset_custom_1 & s3://bucket-cjmm/raw/dataset_custom_2" + expected = ["s3://bucket-cjmm/raw/dataset_custom_1", "s3://bucket-cjmm/raw/dataset_custom_2"] + result = utils.extract_storage_names(expression) + assert result == expected + + expression = "gs://bucket-name/path/to/data | s3://another-bucket/path" + expected = ["gs://bucket-name/path/to/data", "s3://another-bucket/path"] + result = utils.extract_storage_names(expression) + assert result == expected + + expression = "no_storage_paths_here" + expected = [] + result = utils.extract_storage_names(expression) + assert result == expected From e905f5f15a2089939a7574a79e097e0c908434b5 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 19:57:00 -0300 Subject: [PATCH 15/20] Refactor condition evaluation methods in DagBuilder for improved safety and clarity --- dagfactory/dagbuilder.py | 75 ++++++++++++++++++++++++++-------------- dagfactory/dagfactory.py | 11 +++++- 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index b6b28196..e862a9e5 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -2,6 +2,8 @@ from __future__ import annotations +import ast + # pylint: disable=ungrouped-imports import inspect import os @@ -9,7 +11,7 @@ from copy import deepcopy from datetime import datetime, timedelta from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Union from airflow import DAG, configuration from airflow.models import BaseOperator, Variable @@ -83,7 +85,7 @@ from airflow.utils.task_group import TaskGroup from kubernetes.client.models import V1Container, V1Pod -from dagfactory import utils +from dagfactory import parsers, utils from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException # TimeTable is introduced in Airflow 2.2.0 @@ -611,51 +613,80 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]): return task_conf @staticmethod - def evaluate_condition_with_datasets(condition_string: str, datasets_filter: List[str]) -> Any: + def safe_eval(condition_string: str, dataset_map: dict) -> Any: """ - Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + Safely evaluates a condition string using the provided dataset map. - :param condition_string: A string representing the logical condition to evaluate. + :param condition_string: A string representing the condition to evaluate. Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3". :type condition_string: str - :param datasets_filter: A list of dataset URIs to be evaluated in the condition. - :type datasets_filter: List[str] + :param dataset_map: A dictionary where keys are valid variable names (dataset aliases), + and values are Dataset objects. + :type dataset_map: dict + + :returns: The result of evaluating the condition. + :rtype: Any + """ + tree = ast.parse(condition_string, mode='eval') + evaluator = parsers.SafeEvalVisitor(dataset_map) + return evaluator.evaluate(tree) + + @staticmethod + def evaluate_condition_with_datasets(datasets: Union[List[str], str]) -> Any: + """ + Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + + :param datasets: A list or string of dataset URIs to be evaluated in the condition. + :type datasets_filter: Union[List[str], str] :returns: The result of the logical condition evaluation with URIs replaced by valid variable names. :rtype: Any """ dataset_map = {} + datasets_filter = [] + condition_string = "" + if isinstance(datasets, str): + condition_string: str = datasets + datasets_filter: List[str] = utils.extract_dataset_names(datasets) + utils.extract_storage_names(datasets) + else: + datasets_filter: List[str] = datasets + for uri in datasets_filter: valid_variable_name = utils.make_valid_variable_name(uri) condition_string = condition_string.replace(uri, valid_variable_name) dataset_map[valid_variable_name] = Dataset(uri) - evaluated_condition = eval(condition_string, {}, dataset_map) + evaluated_condition = DagBuilder.safe_eval(condition_string, dataset_map) return evaluated_condition @staticmethod def process_file_with_datasets( - file: str, datasets_filter: List[str], condition_string: Optional[str] = None + file: str, datasets: Union[List[str], str] ) -> Any: """ Processes datasets from a file and evaluates conditions if provided. :param file: The file path containing dataset information in a YAML or other structured format. :type file: str - :param datasets_filter: A list of dataset names to filter and process. - :type datasets_filter: List[str] - :param condition_string: A logical condition string to evaluate using the datasets. - If not provided, the function returns a list of `Dataset` objects based on the file and filter. - Example: "(dataset_custom_1 & dataset_custom_2) | dataset_custom_3". - :type condition_string: Optional[str] + :param datasets: A list of dataset or string of dataset names to filter and process. + :type datasets_filter: Union[List[str], str] :returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects. :rtype: Any """ is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") + dataset_map = {} + condition_string = "" + if isinstance(datasets, str): + condition_string: str = datasets + datasets_filter: List[str] = utils.extract_dataset_names(datasets) + utils.extract_storage_names(datasets) + else: + datasets_filter: List[str] = datasets + if condition_string and is_airflow_version_at_least_2_9: map_datasets = utils.get_datasets_map_uri_yaml_file(file, datasets_filter) dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()} - return eval(condition_string, {}, dataset_map) + evaluated_condition = DagBuilder.safe_eval(condition_string, dataset_map) + return evaluated_condition else: datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) return [Dataset(uri) for uri in datasets_uri] @@ -684,19 +715,15 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - has_file_attr = utils.check_dict_key(schedule, "file") has_datasets_attr = utils.check_dict_key(schedule, "datasets") - has_conditions_attr = utils.check_dict_key(schedule, "conditions") if has_file_attr and has_datasets_attr: file = schedule.get("file") datasets_filter = schedule.get("datasets") - condition_string = schedule.get("conditions") - - dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_filter, condition_string) + dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_filter) - elif has_conditions_attr and has_datasets_attr and is_airflow_version_at_least_2_9: + elif has_datasets_attr and is_airflow_version_at_least_2_9: datasets_filter = schedule["datasets"] - condition_string = schedule["conditions"] - dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(condition_string, datasets_filter) + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_filter) else: dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] @@ -705,8 +732,6 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - schedule.pop("file") if has_datasets_attr: schedule.pop("datasets") - if has_conditions_attr: - schedule.pop("conditions") # pylint: disable=too-many-locals def build(self) -> Dict[str, Union[str, DAG]]: diff --git a/dagfactory/dagfactory.py b/dagfactory/dagfactory.py index a0f2fc8e..0791edae 100644 --- a/dagfactory/dagfactory.py +++ b/dagfactory/dagfactory.py @@ -78,11 +78,20 @@ def _load_config(config_filepath: str) -> Dict[str, Any]: def __join(loader: yaml.FullLoader, node: yaml.Node) -> str: seq = loader.construct_sequence(node) return "".join([str(i) for i in seq]) + + def __or(loader: yaml.FullLoader, node: yaml.Node) -> str: + seq = loader.construct_sequence(node) + return " | ".join([f"({str(i)})" for i in seq]) + + def __and(loader: yaml.FullLoader, node: yaml.Node) -> str: + seq = loader.construct_sequence(node) + return " & ".join([f"({str(i)})" for i in seq]) yaml.add_constructor("!join", __join, yaml.FullLoader) + yaml.add_constructor("!or", __or, yaml.FullLoader) + yaml.add_constructor("!and", __and, yaml.FullLoader) with open(config_filepath, "r", encoding="utf-8") as fp: - yaml.add_constructor("!join", __join, yaml.FullLoader) config_with_env = os.path.expandvars(fp.read()) config: Dict[str, Any] = yaml.load(stream=config_with_env, Loader=yaml.FullLoader) except Exception as err: From 6958e24574eede540b53c595b936fe226f9ff138 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 20:03:10 -0300 Subject: [PATCH 16/20] Lint --- dagfactory/dagbuilder.py | 6 ++---- dagfactory/dagfactory.py | 2 +- dagfactory/parsers.py | 8 ++++---- dagfactory/utils.py | 7 +++++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index e862a9e5..96a3c976 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -627,7 +627,7 @@ def safe_eval(condition_string: str, dataset_map: dict) -> Any: :returns: The result of evaluating the condition. :rtype: Any """ - tree = ast.parse(condition_string, mode='eval') + tree = ast.parse(condition_string, mode="eval") evaluator = parsers.SafeEvalVisitor(dataset_map) return evaluator.evaluate(tree) @@ -659,9 +659,7 @@ def evaluate_condition_with_datasets(datasets: Union[List[str], str]) -> Any: return evaluated_condition @staticmethod - def process_file_with_datasets( - file: str, datasets: Union[List[str], str] - ) -> Any: + def process_file_with_datasets(file: str, datasets: Union[List[str], str]) -> Any: """ Processes datasets from a file and evaluates conditions if provided. diff --git a/dagfactory/dagfactory.py b/dagfactory/dagfactory.py index 0791edae..f896c2c1 100644 --- a/dagfactory/dagfactory.py +++ b/dagfactory/dagfactory.py @@ -78,7 +78,7 @@ def _load_config(config_filepath: str) -> Dict[str, Any]: def __join(loader: yaml.FullLoader, node: yaml.Node) -> str: seq = loader.construct_sequence(node) return "".join([str(i) for i in seq]) - + def __or(loader: yaml.FullLoader, node: yaml.Node) -> str: seq = loader.construct_sequence(node) return " | ".join([f"({str(i)})" for i in seq]) diff --git a/dagfactory/parsers.py b/dagfactory/parsers.py index 642c57c3..7584425d 100644 --- a/dagfactory/parsers.py +++ b/dagfactory/parsers.py @@ -15,16 +15,16 @@ def visit_BinOp(self, node): left = self.visit(node.left) right = self.visit(node.right) - if isinstance(node.op, ast.BitAnd): + if isinstance(node.op, ast.BitAnd): return left & right - elif isinstance(node.op, ast.BitOr): + elif isinstance(node.op, ast.BitOr): return left | right else: raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}") def visit_UnaryOp(self, node): operand = self.visit(node.operand) - if isinstance(node.op, ast.Not): + if isinstance(node.op, ast.Not): return ~operand else: raise ValueError(f"Unsupported unary operation: {type(node.op).__name__}") @@ -38,4 +38,4 @@ def visit_Constant(self, node): return node.value def generic_visit(self, node): - raise ValueError(f"Unsupported syntax: {type(node).__name__}") \ No newline at end of file + raise ValueError(f"Unsupported syntax: {type(node).__name__}") diff --git a/dagfactory/utils.py b/dagfactory/utils.py index c9e98ed2..e7e01fe5 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -299,15 +299,18 @@ def get_datasets_map_uri_yaml_file(file_path: str, datasets_filter: str) -> Dict logging.error("Error: File '%s' not found.", file_path) raise + def extract_dataset_names(expression) -> List[str]: - dataset_pattern = r'\b[a-zA-Z_][a-zA-Z0-9_]*\b' + dataset_pattern = r"\b[a-zA-Z_][a-zA-Z0-9_]*\b" datasets = re.findall(dataset_pattern, expression) return datasets + def extract_storage_names(expression) -> List[str]: - storage_pattern = r'[a-zA-Z][a-zA-Z0-9+.-]*://[a-zA-Z0-9\-_/\.]+' + storage_pattern = r"[a-zA-Z][a-zA-Z0-9+.-]*://[a-zA-Z0-9\-_/\.]+" storages = re.findall(storage_pattern, expression) return storages + def make_valid_variable_name(uri) -> str: return re.sub(r"\W|^(?=\d)", "_", uri) From 10150bbeaa1e941337f43d68e9b85f6ed32b7916 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 20:05:42 -0300 Subject: [PATCH 17/20] Update __all__ in dagfactory and adjust expected output for invalid dataset names --- dagfactory/__init__.py | 2 +- tests/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index 1d5e3a5b..a51e5df4 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -5,5 +5,5 @@ __version__ = "0.21.0" __all__ = [ "DagFactory", - "load_yaml_dags" + "load_yaml_dags", ] diff --git a/tests/test_utils.py b/tests/test_utils.py index b7b50ce1..de5a1c5d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -305,7 +305,7 @@ def test_extract_dataset_names(): assert result == expected expression = "123_invalid_dataset" - expected = ["invalid_dataset"] + expected = [] result = utils.extract_dataset_names(expression) assert result == expected From cca2b54176b1e6505e892af2529438da2015f9cd Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 20:08:52 -0300 Subject: [PATCH 18/20] Remove unsupported unary operation handling from SafeEvalVisitor --- dagfactory/parsers.py | 7 ------- tests/test_parsers.py | 7 ------- 2 files changed, 14 deletions(-) diff --git a/dagfactory/parsers.py b/dagfactory/parsers.py index 7584425d..4da385ae 100644 --- a/dagfactory/parsers.py +++ b/dagfactory/parsers.py @@ -22,13 +22,6 @@ def visit_BinOp(self, node): else: raise ValueError(f"Unsupported binary operation: {type(node.op).__name__}") - def visit_UnaryOp(self, node): - operand = self.visit(node.operand) - if isinstance(node.op, ast.Not): - return ~operand - else: - raise ValueError(f"Unsupported unary operation: {type(node.op).__name__}") - def visit_Name(self, node): if node.id in self.dataset_map: return self.dataset_map[node.id] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index fe6af3e7..759ecfee 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -35,13 +35,6 @@ def test_visit_BinOp_or(visitor): expected = 1 | 3 assert result == expected -def test_visit_UnaryOp_not(visitor): - condition_string = "~dataset_custom_1" - tree = ast.parse(condition_string, mode='eval') - result = visitor.evaluate(tree) - expected = ~1 - assert result == expected - def test_visit_Name(visitor): condition_string = "dataset_custom_2" tree = ast.parse(condition_string, mode='eval') From cf7df8143ee16387a868c896a3560aa0d76015d4 Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Thu, 2 Jan 2025 20:12:56 -0300 Subject: [PATCH 19/20] Fix formatting in __all__ declaration in dagfactory --- dagfactory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index a51e5df4..b86179da 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -4,6 +4,6 @@ __version__ = "0.21.0" __all__ = [ - "DagFactory", + "DagFactory", "load_yaml_dags", ] From 4ecb2b6f5fde847e94475ec1ad0926d640fbaaab Mon Sep 17 00:00:00 2001 From: ErickSeo Date: Fri, 3 Jan 2025 09:48:18 -0300 Subject: [PATCH 20/20] Refactor dataset handling in DagBuilder to improve clarity and introduce utility for parsing dataset conditions --- dagfactory/dagbuilder.py | 75 ++++++++++++++++++++++------------------ dagfactory/utils.py | 6 ++++ 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 96a3c976..f8a758f7 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -11,7 +11,7 @@ from copy import deepcopy from datetime import datetime, timedelta from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Tuple, Union from airflow import DAG, configuration from airflow.models import BaseOperator, Variable @@ -632,61 +632,66 @@ def safe_eval(condition_string: str, dataset_map: dict) -> Any: return evaluator.evaluate(tree) @staticmethod - def evaluate_condition_with_datasets(datasets: Union[List[str], str]) -> Any: + def _extract_and_transform_datasets(datasets_conditions: str) -> Tuple[str, Dict[str, Any]]: """ - Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + Extracts dataset names and storage paths from the conditions string and transforms them into valid variable names. - :param datasets: A list or string of dataset URIs to be evaluated in the condition. - :type datasets_filter: Union[List[str], str] + :param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition. + :type datasets_conditions: str - :returns: The result of the logical condition evaluation with URIs replaced by valid variable names. - :rtype: Any + :returns: A tuple containing the transformed conditions string and the dataset map. + :rtype: Tuple[str, Dict[str, Any]] """ dataset_map = {} - datasets_filter = [] - condition_string = "" - if isinstance(datasets, str): - condition_string: str = datasets - datasets_filter: List[str] = utils.extract_dataset_names(datasets) + utils.extract_storage_names(datasets) - else: - datasets_filter: List[str] = datasets + datasets_filter: List[str] = utils.extract_dataset_names(datasets_conditions) + utils.extract_storage_names( + datasets_conditions + ) for uri in datasets_filter: valid_variable_name = utils.make_valid_variable_name(uri) - condition_string = condition_string.replace(uri, valid_variable_name) + datasets_conditions = datasets_conditions.replace(uri, valid_variable_name) dataset_map[valid_variable_name] = Dataset(uri) - evaluated_condition = DagBuilder.safe_eval(condition_string, dataset_map) + + return datasets_conditions, dataset_map + + @staticmethod + def evaluate_condition_with_datasets(datasets_conditions: str) -> Any: + """ + Evaluates a condition using the dataset filter, transforming URIs into valid variable names. + + :param datasets_conditions: A string of conditions dataset URIs to be evaluated in the condition. + :type datasets_conditions: str + + :returns: The result of the logical condition evaluation with URIs replaced by valid variable names. + :rtype: Any + """ + datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions) + evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map) return evaluated_condition @staticmethod - def process_file_with_datasets(file: str, datasets: Union[List[str], str]) -> Any: + def process_file_with_datasets(file: str, datasets_conditions: str) -> Any: """ Processes datasets from a file and evaluates conditions if provided. :param file: The file path containing dataset information in a YAML or other structured format. :type file: str - :param datasets: A list of dataset or string of dataset names to filter and process. - :type datasets_filter: Union[List[str], str] + :param datasets_conditions: A string of dataset conditions to filter and process. + :type datasets_conditions: str :returns: The result of the condition evaluation if `condition_string` is provided, otherwise a list of `Dataset` objects. :rtype: Any """ is_airflow_version_at_least_2_9 = version.parse(AIRFLOW_VERSION) >= version.parse("2.9.0") - dataset_map = {} - condition_string = "" - if isinstance(datasets, str): - condition_string: str = datasets - datasets_filter: List[str] = utils.extract_dataset_names(datasets) + utils.extract_storage_names(datasets) - else: - datasets_filter: List[str] = datasets + datasets_conditions, dataset_map = DagBuilder._extract_and_transform_datasets(datasets_conditions) - if condition_string and is_airflow_version_at_least_2_9: - map_datasets = utils.get_datasets_map_uri_yaml_file(file, datasets_filter) + if is_airflow_version_at_least_2_9: + map_datasets = utils.get_datasets_map_uri_yaml_file(file, list(dataset_map.keys())) dataset_map = {alias_dataset: Dataset(uri) for alias_dataset, uri in map_datasets.items()} - evaluated_condition = DagBuilder.safe_eval(condition_string, dataset_map) + evaluated_condition = DagBuilder.safe_eval(datasets_conditions, dataset_map) return evaluated_condition else: - datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) + datasets_uri = utils.get_datasets_uri_yaml_file(file, list(dataset_map.keys())) return [Dataset(uri) for uri in datasets_uri] @staticmethod @@ -716,12 +721,14 @@ def configure_schedule(dag_params: Dict[str, Any], dag_kwargs: Dict[str, Any]) - if has_file_attr and has_datasets_attr: file = schedule.get("file") - datasets_filter = schedule.get("datasets") - dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_filter) + datasets: Union[List[str], str] = schedule.get("datasets") + datasets_conditions: str = utils.parse_list_datasets(datasets) + dag_kwargs["schedule"] = DagBuilder.process_file_with_datasets(file, datasets_conditions) elif has_datasets_attr and is_airflow_version_at_least_2_9: - datasets_filter = schedule["datasets"] - dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_filter) + datasets = schedule["datasets"] + datasets_conditions: str = utils.parse_list_datasets(datasets) + dag_kwargs["schedule"] = DagBuilder.evaluate_condition_with_datasets(datasets_conditions) else: dag_kwargs["schedule"] = [Dataset(uri) for uri in schedule] diff --git a/dagfactory/utils.py b/dagfactory/utils.py index e7e01fe5..0d5eb08d 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -314,3 +314,9 @@ def extract_storage_names(expression) -> List[str]: def make_valid_variable_name(uri) -> str: return re.sub(r"\W|^(?=\d)", "_", uri) + + +def parse_list_datasets(datasets: Union[List[str], str]) -> str: + if isinstance(datasets, list): + datasets = " & ".join(datasets) + return datasets