Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 20 additions & 63 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import datetime
import os
from pathlib import Path
from unittest.mock import patch

import pendulum
import pytest
from airflow import DAG

from packaging import version

try:
Expand Down Expand Up @@ -36,12 +35,6 @@

from dagfactory import dagbuilder

if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
from airflow.timetables.interval import CronDataIntervalTimetable
else:
Timetable = None
# pylint: disable=ungrouped-imports,invalid-name

if version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
from airflow.models import MappedOperator
else:
Expand Down Expand Up @@ -461,7 +454,7 @@ def test_build():
assert actual["dag"].tags == ["tag1", "tag2"]


def test_get_dag_params():
def test_get_dag_params_dag_with_task_group():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP, DEFAULT_CONFIG)
expected = {
"default_args": {
Expand Down Expand Up @@ -534,21 +527,13 @@ def test_build_task_groups():
td.build()
else:
actual = td.build()
task_group_1 = {
t for t in actual["dag"].task_dict if t.startswith("task_group_1")
}
task_group_2 = {
t for t in actual["dag"].task_dict if t.startswith("task_group_2")
}
task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
assert actual["dag_id"] == "test_dag"
assert isinstance(actual["dag"], DAG)
assert len(actual["dag"].tasks) == 6
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {
"task_group_1.task_2"
}
assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {
"task_group_1.task_3"
}
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {"task_group_1.task_3"}
assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
"task_4",
"task_group_2.task_5",
Expand All @@ -569,9 +554,7 @@ def test_make_task_groups():
}
dag = "dag"
task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag)
expected = MockTaskGroup(
tooltip="this is a task group", group_id="task_group", dag=dag
)
expected = MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag)
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
assert task_groups == {}
else:
Expand Down Expand Up @@ -628,30 +611,22 @@ def test_make_dag_with_callback():
def test_get_dag_params_with_template_searchpath():
from dagfactory import utils

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": ["./sql"]}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": ["./sql"]}, DEFAULT_CONFIG)
error_message = "template_searchpath must be absolute paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": ["/sql"]}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": ["/sql"]}, DEFAULT_CONFIG)
error_message = "template_searchpath must be existing paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": "./sql"}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": "./sql"}, DEFAULT_CONFIG)
error_message = "template_searchpath must be absolute paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": "/sql"}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": "/sql"}, DEFAULT_CONFIG)
error_message = "template_searchpath must be existing paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()
Expand All @@ -662,26 +637,20 @@ def test_get_dag_params_with_template_searchpath():


def test_get_dag_params_with_render_template_as_native_obj():
td = dagbuilder.DagBuilder(
"test_dag", {"render_template_as_native_obj": "true"}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"render_template_as_native_obj": "true"}, DEFAULT_CONFIG)
error_message = "render_template_as_native_obj should be bool type!"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

false = lambda x: print(x)
td = dagbuilder.DagBuilder(
"test_dag", {"render_template_as_native_obj": false}, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", {"render_template_as_native_obj": false}, DEFAULT_CONFIG)
error_message = "render_template_as_native_obj should be bool type!"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()


def test_make_task_with_duplicated_partial_kwargs():
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
operator = "airflow.operators.bash_operator.BashOperator"
task_params = {
"task_id": "task_bash",
Expand All @@ -693,9 +662,7 @@ def test_make_task_with_duplicated_partial_kwargs():


def test_dynamic_task_mapping():
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
error_message = "Dynamic task mapping available only in Airflow >= 2.3.0"
with pytest.raises(Exception, match=error_message):
Expand All @@ -715,9 +682,7 @@ def test_dynamic_task_mapping():

@patch("dagfactory.dagbuilder.PythonOperator", new=MockPythonOperator)
def test_replace_expand_string_with_xcom():
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
with pytest.raises(Exception):
td.build()
Expand All @@ -727,15 +692,7 @@ def test_replace_expand_string_with_xcom():
task_conf_output = {"expand": {"key_1": "task_1.output"}}
task_conf_xcomarg = {"expand": {"key_1": "XcomArg(task_1)"}}
tasks_dict = {"task_1": MockPythonOperator()}
updated_task_conf_output = dagbuilder.DagBuilder.replace_expand_values(
task_conf_output, tasks_dict
)
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(
task_conf_xcomarg, tasks_dict
)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(
tasks_dict["task_1"]
)
assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(
tasks_dict["task_1"]
)
updated_task_conf_output = dagbuilder.DagBuilder.replace_expand_values(task_conf_output, tasks_dict)
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"])