Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dagfactory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .dagfactory import DagFactory, load_yaml_dags

__version__ = "0.23.0a3"
__version__ = "0.23.0a4"
__all__ = [
"DagFactory",
"load_yaml_dags",
Expand Down
223 changes: 106 additions & 117 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,28 @@

try:
from airflow.version import version as AIRFLOW_VERSION
except ImportError:
except ImportError: # pragma: no cover
from airflow import __version__ as AIRFLOW_VERSION


try:
from airflow.providers.cncf.kubernetes import get_provider_info

try:
K8S_PROVIDER_VERSION = get_provider_info.get_provider_info()["versions"][0]
except KeyError: # pragma: no cover
from airflow.providers.cncf.kubernetes import __version__

K8S_PROVIDER_VERSION = __version__
except ImportError: # pragma: no cover
K8S_PROVIDER_VERSION = "0"

INSTALLED_AIRFLOW_VERSION = version.parse(AIRFLOW_VERSION)

# python operators were moved in 2.4
try:
from airflow.operators.python import BranchPythonOperator, PythonOperator
except ImportError:
except ImportError: # pragma: no cover
from airflow.operators.python_operator import BranchPythonOperator, PythonOperator

from airflow.providers.http.sensors.http import HttpSensor
Expand All @@ -42,67 +54,50 @@
from airflow.providers.http.operators.http import HttpOperator

HTTP_OPERATOR_CLASS = HttpOperator
except ImportError:
except ImportError: # pragma: no cover
try:
from airflow.providers.http.operators.http import SimpleHttpOperator

HTTP_OPERATOR_CLASS = SimpleHttpOperator
except ImportError:
except ImportError: # pragma: no cover
# Fall back to dynamically importing the operator
HTTP_OPERATOR_CLASS = None


# sql sensor was moved in 2.4
try:
from airflow.sensors.sql_sensor import SqlSensor
except ImportError:
except ImportError: # pragma: no cover
from airflow.providers.common.sql.sensors.sql import SqlSensor

from airflow.sensors.python import PythonSensor

if INSTALLED_AIRFLOW_VERSION.major < AIRFLOW3_MAJOR_VERSION:
# k8s libraries are moved in v5.0.0
try:
from airflow.providers.cncf.kubernetes import get_provider_info

K8S_PROVIDER_VERSION = get_provider_info.get_provider_info()["versions"][0]
except ImportError:
K8S_PROVIDER_VERSION = "0"
from airflow.models import MappedOperator

# kubernetes operator
try:
if version.parse(K8S_PROVIDER_VERSION) < version.parse("5.0.0"):
from airflow.kubernetes.pod import Port
from airflow.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.kubernetes.volume import Volume
from airflow.kubernetes.volume_mount import VolumeMount
else:
from kubernetes.client.models import (
V1ContainerPort as Port,
V1EnvVar,
V1EnvVarSource,
V1ObjectFieldSelector,
V1Volume,
V1VolumeMount as VolumeMount,
)
from airflow.kubernetes.secret import Secret
try:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
except ImportError:
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

if version.parse(K8S_PROVIDER_VERSION) < version.parse("10"):
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
else:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
except ImportError: # pragma: no cover
from airflow.contrib.kubernetes.pod import Port
from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.contrib.kubernetes.secret import Secret
from airflow.contrib.kubernetes.volume import Volume
from airflow.contrib.kubernetes.volume_mount import VolumeMount
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
try:
from airflow.providers.cncf.kubernetes.secret import Secret
except ImportError:
from airflow.kubernetes.secret import Secret

from airflow.models import MappedOperator
from airflow.sensors.python import PythonSensor
from airflow.timetables.base import Timetable
from airflow.utils.task_group import TaskGroup
from kubernetes.client.models import V1Container, V1Pod
from kubernetes.client.models import (
V1Affinity,
V1Container,
V1ContainerPort as Port,
V1EnvFromSource,
V1EnvVar,
V1LocalObjectReference,
V1Pod,
V1PodSecurityContext,
V1Toleration,
V1Volume,
V1VolumeMount as VolumeMount,
)

from dagfactory import parsers, utils
from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException
Expand Down Expand Up @@ -267,10 +262,72 @@ def make_timetable(timetable: str, timetable_params: Dict[str, Any]) -> Timetabl
raise DagFactoryException(f"Failed to import timetable {timetable} due to: {err}") from err
try:
schedule: Timetable = timetable_obj(**timetable_params)
except Exception as err:
except Exception as err: # pragma: no cover
Comment thread
tatiana marked this conversation as resolved.
raise DagFactoryException(f"Failed to create {timetable_obj} due to: {err}") from err
return schedule

@staticmethod
def _create_volume(vol):
    """Build a ``V1Volume`` from a volume definition mapping.

    :param vol: mapping with an optional ``name`` entry and a required
        ``configs`` mapping. Each ``configs`` key is converted to snake case
        with ``utils.convert_to_snake_case`` and its value is assigned to the
        attribute of that name on the volume.
    :returns: the populated ``V1Volume`` instance.
    :raises DagFactoryException: if a converted key is not an attribute of
        ``V1Volume``.
    """
    k8s_volume = V1Volume(name=vol.get("name"))
    for config_key, config_value in vol["configs"].items():
        attribute_name = utils.convert_to_snake_case(config_key)
        # Reject unknown keys up front instead of silently attaching stray attributes.
        if not hasattr(k8s_volume, attribute_name):
            raise DagFactoryException(f"Volume for KubernetesPodOperator does not have attribute {config_key}")
        setattr(k8s_volume, attribute_name, config_value)
    return k8s_volume

@staticmethod
def _clean_kpo_task_params(task_params: dict) -> dict:
    """Convert plain KubernetesPodOperator parameters into model objects.

    Every parameter named in the conversion table that is present in
    ``task_params`` and is not ``None`` is replaced in place:

    * ``"list"`` entries become ``[cls(**item) for item in value]``;
    * ``"single"`` entries become ``cls(**value)``.

    Some parameters are only converted when ``K8S_PROVIDER_VERSION`` is at
    least a given version (presumably the provider release that introduced
    the argument -- TODO confirm against the provider changelog).
    ``volumes`` is handled separately through ``DagBuilder._create_volume``.

    :param task_params: task parameters; mutated in place.
    :returns: the same ``task_params`` dict, after conversion.
    """
    conversions = [
        ("ports", Port, "list"),
        ("volume_mounts", VolumeMount, "list"),
        ("env_vars", V1EnvVar, "list"),
        ("env_from", V1EnvFromSource, "list"),
        ("secrets", Secret, "list"),
        ("affinity", V1Affinity, "single"),
        ("image_pull_secrets", V1LocalObjectReference, "list"),
        ("tolerations", V1Toleration, "list"),
        ("security_context", V1PodSecurityContext, "single"),
        ("init_containers", V1Container, "list"),
        ("pod_runtime_info_envs", V1EnvVar, "list"),
        ("full_pod_spec", V1Pod, "single"),
    ]

    # Parse the provider version once instead of once per comparison.
    provider_version = version.parse(K8S_PROVIDER_VERSION)

    # Conditional fields based on the installed provider version.
    if provider_version >= version.parse("7.8.0"):
        from kubernetes.client.models import V1HostAlias

        conversions.append(("host_aliases", V1HostAlias, "list"))

    if provider_version >= version.parse("7.0.0"):
        from kubernetes.client.models import V1PodDNSConfig

        conversions.append(("dns_config", V1PodDNSConfig, "single"))

    if provider_version >= version.parse("5.0.0"):
        from kubernetes.client.models import V1ResourceRequirements

        conversions.append(("container_resources", V1ResourceRequirements, "single"))

    if provider_version >= version.parse("4.4.0"):
        from kubernetes.client.models import V1SecurityContext

        conversions.append(("container_security_context", V1SecurityContext, "single"))

    for key, cls, conv_type in conversions:
        raw_value = task_params.get(key)
        if raw_value is None:
            # Absent or explicitly None: leave the parameter untouched.
            continue
        if conv_type == "list":
            task_params[key] = [cls(**item) for item in raw_value]
        elif conv_type == "single":
            # Unpack the mapping into keyword arguments, as the list branch does.
            # Passing the dict positionally would bind the whole dict to the
            # model's first constructor parameter instead of populating its fields.
            task_params[key] = cls(**raw_value)

    # Special case for volumes, which use a different constructor.
    if task_params.get("volumes") is not None:
        task_params["volumes"] = [DagBuilder._create_volume(vol) for vol in task_params["volumes"]]

    return task_params

# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals
Expand Down Expand Up @@ -369,76 +426,8 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
# Airflow 2.0 doesn't allow these to be passed to operator
del task_params["response_check_lambda"]

if INSTALLED_AIRFLOW_VERSION.major < AIRFLOW3_MAJOR_VERSION:
# KubernetesPodOperator
if issubclass(operator_obj, KubernetesPodOperator):
task_params["secrets"] = (
[Secret(**v) for v in task_params.get("secrets")]
if task_params.get("secrets") is not None
else None
)

task_params["ports"] = (
[Port(**v) for v in task_params.get("ports")] if task_params.get("ports") is not None else None
)
task_params["volume_mounts"] = (
[VolumeMount(**v) for v in task_params.get("volume_mounts")]
if task_params.get("volume_mounts") is not None
else None
)
if version.parse(K8S_PROVIDER_VERSION) < version.parse("5.0.0"):
task_params["volumes"] = (
[Volume(**v) for v in task_params.get("volumes")]
if task_params.get("volumes") is not None
else None
)
task_params["pod_runtime_info_envs"] = (
[PodRuntimeInfoEnv(**v) for v in task_params.get("pod_runtime_info_envs")]
if task_params.get("pod_runtime_info_envs") is not None
else None
)
else:
if task_params.get("volumes") is not None:
task_params_volumes = []
for vol in task_params.get("volumes"):
resp = V1Volume(name=vol.get("name"))
for k, v in vol["configs"].items():
snake_key = utils.convert_to_snake_case(k)
if hasattr(resp, snake_key):
setattr(resp, snake_key, v)
else:
raise DagFactoryException(
f"Volume for KubernetesPodOperator \
does not have attribute {k}"
)
task_params_volumes.append(resp)
task_params["volumes"] = task_params_volumes
else:
task_params["volumes"] = None

task_params["pod_runtime_info_envs"] = (
[
V1EnvVar(
name=v.get("name"),
value_from=V1EnvVarSource(
field_ref=V1ObjectFieldSelector(field_path=v.get("field_path"))
),
)
for v in task_params.get("pod_runtime_info_envs")
]
if task_params.get("pod_runtime_info_envs") is not None
else None
)
task_params["full_pod_spec"] = (
V1Pod(**task_params.get("full_pod_spec"))
if task_params.get("full_pod_spec") is not None
else None
)
task_params["init_containers"] = (
[V1Container(**v) for v in task_params.get("init_containers")]
if task_params.get("init_containers") is not None
else None
)
if issubclass(operator_obj, KubernetesPodOperator):
task_params = DagBuilder._clean_kpo_task_params(task_params)

# HttpOperator
if HTTP_OPERATOR_CLASS and issubclass(operator_obj, HTTP_OPERATOR_CLASS):
Expand Down Expand Up @@ -469,7 +458,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
else operator_obj.partial(**task_params).expand(**expand_kwargs)
)
except Exception as err:
raise DagFactoryException(f"Failed to create {operator_obj} task") from err
raise DagFactoryException(f"Failed to create {operator_obj} task: {err}") from err
return task

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
dependencies = [
"apache-airflow>=2.3",
"apache-airflow-providers-http>=2.0.0",
"apache-airflow-providers-cncf-kubernetes<10.4.2", # https://github.com/astronomer/dag-factory/issues/397
"apache-airflow-providers-cncf-kubernetes",
"pyyaml",
"packaging",
]
Expand Down
8 changes: 4 additions & 4 deletions tests/fixtures/dag_factory_kubernetes_pod_operator.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ default:
example_dag:
tasks:
task_1:
operator: airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator
operator: airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator
namespace: 'default'
config_file : 'path_to_config_file'
image : 'image'
Expand All @@ -35,8 +35,8 @@ example_dag:
{"name":"name","configs":{'persistentVolumeClaim': {'claimName': 'test-volume'}}},
]
pod_runtime_info_envs : [
{"name":"name","field_path":"field_path"},
{"name":"name","field_path":"field_path"},
{"name":"name","value":"field_path"},
{"name":"name","value":"field_path"},
]
full_pod_spec : {
"api_version": "api_version",
Expand All @@ -55,7 +55,7 @@ example_dag:
in_cluster: False
dependencies: []
task_2:
operator: airflow.contrib.operators.kubernetes_pod_operator.KubernetesPodOperator
operator: airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator
namespace: 'default'
config_file : 'path_to_config_file'
image : 'image'
Expand Down
Loading