diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index a38b9868..a788ef41 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -1006,20 +1006,24 @@ def adjust_general_task_params(task_params: dict(str, Any)): task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None) del task_params["variables_as_arguments"] - if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"): - if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key( - task_params["outlets"], "datasets" - ): - file = task_params["outlets"]["file"] - datasets_filter = task_params["outlets"]["datasets"] - datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) - - del task_params["outlets"]["file"] - del task_params["outlets"]["datasets"] - else: - datasets_uri = task_params["outlets"] + if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"): + # Treat inlets and outlets alike: a file/datasets mapping is resolved through utils.get_datasets_uri_yaml_file, any other value is used as the URI list as-is. + for key in ["inlets", "outlets"]: + if utils.check_dict_key(task_params, key): + if utils.check_dict_key(task_params[key], "file") and utils.check_dict_key( + task_params[key], "datasets" + ): + file = task_params[key]["file"] + datasets_filter = task_params[key]["datasets"] + datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) + + del task_params[key]["file"] + del task_params[key]["datasets"] + else: + datasets_uri = task_params[key] - task_params["outlets"] = [Dataset(uri) for uri in datasets_uri] + if key in task_params and datasets_uri: + task_params[key] = [Dataset(uri) for uri in datasets_uri] @staticmethod def make_decorator( diff --git a/dev/dags/datasets/example_dag_datasets.yml b/dev/dags/datasets/example_dag_datasets.yml index ec14def9..5d7231a1 100644 --- a/dev/dags/datasets/example_dag_datasets.yml +++ b/dev/dags/datasets/example_dag_datasets.yml @@ -17,11 +17,13 @@ example_simple_dataset_producer_dag: task_1: operator: 
airflow.operators.bash_operator.BashOperator bash_command: "echo 1" + inlets: [ 's3://bucket_example/raw/dataset1_source.json' ] outlets: ['s3://bucket_example/raw/dataset1.json'] task_2: operator: airflow.operators.bash_operator.BashOperator bash_command: "echo 2" dependencies: [task_1] + inlets: [ 's3://bucket_example/raw/dataset2_source.json' ] outlets: ['s3://bucket_example/raw/dataset2.json'] example_simple_dataset_consumer_dag: diff --git a/dev/dags/datasets/example_dag_datasets_outlet.yml b/dev/dags/datasets/example_dag_datasets_outlet_inlet.yml similarity index 86% rename from dev/dags/datasets/example_dag_datasets_outlet.yml rename to dev/dags/datasets/example_dag_datasets_outlet_inlet.yml index d76a08e8..709a2de1 100644 --- a/dev/dags/datasets/example_dag_datasets_outlet.yml +++ b/dev/dags/datasets/example_dag_datasets_outlet_inlet.yml @@ -9,10 +9,12 @@ producer_dag: task_1: operator: airflow.operators.bash_operator.BashOperator bash_command: "echo 1" + inlets: [ 's3://bucket_example/raw/dataset1_source.json' ] outlets: [ 's3://bucket_example/raw/dataset1.json' ] task_2: bash_command: "echo 2" dependencies: [ task_1 ] + inlets: [ 's3://bucket_example/raw/dataset2_source.json' ] outlets: [ 's3://bucket_example/raw/dataset2.json' ] consumer_dag: default_args: diff --git a/docs/features/datasets.md b/docs/features/datasets.md index 7092a5d1..6a5d8e7a 100644 --- a/docs/features/datasets.md +++ b/docs/features/datasets.md @@ -1,16 +1,18 @@ # Datasets DAG Factory supports Airflow’s [Datasets](https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/datasets.html). -## Datasets Outlets +## Datasets Outlets and Inlets -To leverage, you need to specify the `Dataset` in the `outlets` key in the configuration file. The `outlets` key is a list of strings that represent the dataset locations. -In the `schedule` key of the consumer dag, you can set the `Dataset` you would like to schedule against. 
The key is a list of strings that represent the dataset locations. -The consumer dag will run when all the datasets are available. +To leverage datasets, you need to specify the `Dataset` in the `outlets` and `inlets` keys in the configuration file. +The `outlets` and `inlets` keys should contain a list of strings representing dataset locations. +In the `schedule` key of the consumer DAG, you can set the `Dataset` that the DAG should be scheduled against. The key +should contain a list of dataset locations. +The consumer DAG will run when all the specified datasets become available. -#### Example: Outlet +#### Example: Outlet and Inlet -```title="example_dag_datasets_outlet.yml" ---8<-- "dev/dags/datasets/example_dag_datasets_outlet.yml" +```title="example_dag_datasets_outlet_inlet.yml" +--8<-- "dev/dags/datasets/example_dag_datasets_outlet_inlet.yml" ``` ![datasets_example.png](../static/images/datasets/outlets/datasets_example.png "Simple Dataset Producer") diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 67ba0b3c..5850af2e 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -982,34 +982,61 @@ def test_replace_expand_string_with_xcom(): assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"]) assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"]) - @pytest.mark.skipif( version.parse(AIRFLOW_VERSION) <= version.parse("2.4.0"), reason="Requires Airflow version greater than 2.4.0" ) @pytest.mark.parametrize( - "outlets,output", + "inlets, outlets, expected_inlets, expected_outlets", [ + # 1️⃣ Test: inlets are provided, but outlets are None + ( + {"datasets": "s3://test/in.txt", "file": "file://path/to/in_file.txt"}, + None, # No `outlets` + ["s3://test/in.txt", "file://path/to/in_file.txt"], + [], + ), + # 2️⃣ Test: both inlets and outlets are provided + ( + ["s3://test/in.txt"], + ["s3://test/out.txt"], + ["s3://test/in.txt"], + ["s3://test/out.txt"], + ), + # 3️⃣ Test: 
inlets are None, but outlets are provided ( - {"datasets": "s3://test/test.txt", "file": "file://path/to/my_file.txt"}, - ["s3://test/test.txt", "file://path/to/my_file.txt"], + None, # No `inlets` + ["s3://test/out.txt"], # `outlets` exist + [], + ["s3://test/out.txt"], ), - (["s3://test/test.txt"], ["s3://test/test.txt"]), ], ) + @patch("dagfactory.dagbuilder.utils.get_datasets_uri_yaml_file", new_callable=mock_open) -def test_make_task_outlets(mock_read_file, outlets, output): +def test_make_task_inlets_outlets(mock_read_file, inlets, outlets, expected_inlets, expected_outlets): + """Tests if the `make_task()` function correctly handles `inlets` and `outlets` parameters.""" + + # Create a DagBuilder instance td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG) + + # Define task parameters task_params = { "task_id": "process", "python_callable_name": "expand_task", "python_callable_file": os.path.realpath(__file__), + "inlets": inlets, "outlets": outlets, } - mock_read_file.return_value = output + + # Mock the response of `get_datasets_uri_yaml_file` to return expected values + mock_read_file.return_value = expected_inlets + expected_outlets + operator = "airflow.operators.python_operator.PythonOperator" actual = td.make_task(operator, task_params) - assert actual.outlets == [Dataset(uri) for uri in output] + # Assertions to check if the actual results match the expected values + assert actual.inlets == [Dataset(uri) for uri in expected_inlets] + assert actual.outlets == [Dataset(uri) for uri in expected_outlets] @patch("dagfactory.dagbuilder.TaskGroup", new=MockTaskGroup) def test_make_nested_task_groups():