diff --git a/airflow/models/param.py b/airflow/models/param.py index d944ef20311d3..2f93f7dd88341 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -21,7 +21,7 @@ import json import logging import warnings -from typing import TYPE_CHECKING, Any, ItemsView, Iterable, MutableMapping, ValuesView +from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning from airflow.utils.context import Context @@ -47,6 +47,8 @@ class Param: default & description will form the schema """ + __version__: ClassVar[int] = 1 + CLASS_IDENTIFIER = "__class" def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs): @@ -112,6 +114,16 @@ def dump(self) -> dict: def has_value(self) -> bool: return self.value is not NOTSET + def serialize(self) -> dict: + return {"value": self.value, "description": self.description, "schema": self.schema} + + @staticmethod + def deserialize(data: dict[str, Any], version: int) -> Param: + if version > Param.__version__: + raise TypeError("serialized version > class version") + + return Param(default=data["value"], description=data["description"], schema=data["schema"]) + class ParamsDict(MutableMapping[str, Any]): """ @@ -120,6 +132,7 @@ class ParamsDict(MutableMapping[str, Any]): dictionary implicitly and ideally not needed to be used directly. """ + __version__: ClassVar[int] = 1 __slots__ = ["__dict", "suppress_exception"] def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False): @@ -231,6 +244,16 @@ def validate(self) -> dict[str, Any]: return resolved_dict + def serialize(self) -> dict[str, Any]: + return self.dump() + + @staticmethod + def deserialize(data: dict, version: int) -> ParamsDict: + if version > ParamsDict.__version__: + raise TypeError("serialized version > class version") + + return ParamsDict(data) + class DagParam(ResolveMixin): """DAG run parameter reference. diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 1d887110208a0..d5fed0d3a4297 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -38,7 +38,7 @@ from airflow import settings from airflow.utils.entry_points import entry_points_with_dist from airflow.utils.file import find_path_from_directory -from airflow.utils.module_loading import as_importable_string +from airflow.utils.module_loading import qualname if TYPE_CHECKING: from airflow.hooks.base import BaseHook @@ -373,7 +373,7 @@ def initialize_ti_deps_plugins(): for plugin in plugins: registered_ti_dep_classes.update( - {as_importable_string(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps} + {qualname(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps} ) @@ -406,7 +406,7 @@ def initialize_extra_operators_links_plugins(): operator_extra_links.extend(list(plugin.operator_extra_links)) registered_operator_link_classes.update( - {as_importable_string(link.__class__): link.__class__ for link in plugin.operator_extra_links} + {qualname(link.__class__): link.__class__ for link in plugin.operator_extra_links} ) @@ -425,7 +425,7 @@ def initialize_timetables_plugins(): log.debug("Initialize extra timetables plugins") timetable_classes = { - as_importable_string(timetable_class): timetable_class + qualname(timetable_class): timetable_class for plugin in plugins for timetable_class in plugin.timetables } @@ -525,25 +525,20 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str info: dict[str, Any] = {"name": plugin.name} for attr in attrs_to_dump: if attr in ("global_operator_extra_links", "operator_extra_links"): - info[attr] = [ - f"<{as_importable_string(d.__class__)} object>" for d in getattr(plugin, attr) - ] + info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)] elif attr in ("macros", "timetables", "hooks", "executors"): - info[attr] = [as_importable_string(d) for d in getattr(plugin, attr)] + info[attr] = [qualname(d) for d in getattr(plugin, attr)] elif attr == "listeners": # listeners are always modules info[attr] = [d.__name__ for d in getattr(plugin, attr)] elif attr == "appbuilder_views": info[attr] = [ - {**d, "view": as_importable_string(d["view"].__class__) if "view" in d else None} + {**d, "view": qualname(d["view"].__class__) if "view" in d else None} for d in getattr(plugin, attr) ] elif attr == "flask_blueprints": info[attr] = [ - ( - f"<{as_importable_string(d.__class__)}: " - f"name={d.name!r} import_name={d.import_name!r}>" - ) + f"<{qualname(d.__class__)}: name={d.name!r} import_name={d.import_name!r}>" for d in getattr(plugin, attr) ] else: diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py new file mode 100644 index 0000000000000..401edaaa79917 --- /dev/null +++ b/airflow/serialization/serde.py @@ -0,0 +1,306 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import dataclasses +import enum +import logging +import re +import sys +from importlib import import_module +from types import ModuleType +from typing import Any, TypeVar, Union + +import attr + +import airflow.serialization.serializers +from airflow.configuration import conf +from airflow.utils.module_loading import import_string, iter_namespace, qualname + +log = logging.getLogger(__name__) + +MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1 + +CLASSNAME = "__classname__" +VERSION = "__version__" +DATA = "__data__" +SCHEMA_ID = "__id__" +CACHE = "__cache__" + +OLD_TYPE = "__type" +OLD_SOURCE = "__source" +OLD_DATA = "__var" + +DEFAULT_VERSION = 0 + +T = TypeVar("T", bool, float, int, dict, list, str, tuple, set) +U = Union[bool, float, int, dict, list, str, tuple, set] +S = Union[list, tuple, set] + +_serializers: dict[str, ModuleType] = {} +_deserializers: dict[str, ModuleType] = {} +_extra_allowed: set[str] = set() + +_primitives = (int, bool, float, str) +_iterables = (list, set, tuple) +_patterns: list[re.Pattern] = [] + +_reverse_cache: dict[int, tuple[ModuleType, str, int]] = {} + + +def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]: + """Encodes o so it can be understood by the deserializer""" + return {CLASSNAME: cls, VERSION: version, DATA: data} + + +def decode(d: dict[str, str | int | T]) -> tuple: + return d[CLASSNAME], d[VERSION], d.get(DATA, None) + + +def serialize(o: object, depth: int = 0) -> U | None: + """ + Recursively serializes objects into a primitive. Primitives (int, float, int, bool) + are returned as is. Tuples and dicts are iterated over, where it is assumed that keys + for dicts can be represented as str. Values that are not primitives are serialized if + a serializer is found for them. The order in which serializers are used + is 1) a serialize function provided by the object 2) a registered serializer in + the namespace of airflow.serialization.serializers and 3) an attr or dataclass annotations. + If a serializer cannot be found a TypeError is raised. + + :param o: object to serialize + :param depth: private + :return: a primitive + """ + if depth == MAX_RECURSION_DEPTH: + raise RecursionError("maximum recursion depth reached for serialization") + + # None remains None + if o is None: + return o + + # primitive types are returned as is + if isinstance(o, _primitives): + if isinstance(o, enum.Enum): + return o.value + + return o + + # tuples and plain dicts are iterated over recursively + if isinstance(o, _iterables): + s = [serialize(d, depth + 1) for d in o] + if isinstance(o, tuple): + return tuple(s) + if isinstance(o, set): + return set(s) + return s + + if isinstance(o, dict): + if CLASSNAME in o or SCHEMA_ID in o: + raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize") + + return {str(k): serialize(v, depth + 1) for k, v in o.items()} + + cls = type(o) + qn = qualname(o) + + # custom serializers + dct = { + CLASSNAME: qn, + VERSION: getattr(cls, "__version__", DEFAULT_VERSION), + } + + # if there is a builtin serializer available use that + if qn in _serializers: + data, classname, version, is_serialized = _serializers[qn].serialize(o) + if is_serialized: + return encode(classname, version, serialize(data, depth + 1)) + + # object / class brings their own + if hasattr(o, "serialize"): + data = getattr(o, "serialize")() + + # if we end up with a structure, ensure its values are serialized + if isinstance(data, dict): + data = serialize(data, depth + 1) + + dct[DATA] = data + return dct + + # dataclasses + if dataclasses.is_dataclass(cls): + data = dataclasses.asdict(o) + dct[DATA] = serialize(data, depth + 1) + return dct + + # attr annotated + if attr.has(cls): + # Only include attributes which we can pass back to the classes constructor + data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init) # type: ignore[arg-type] + dct[DATA] = serialize(data, depth + 1) + return dct + + raise TypeError(f"cannot serialize object of type {cls}") + + +def deserialize(o: T | None, full=True, type_hint: Any = None) -> object: + """ + Deserializes an object of primitive type T into an object. Uses an allow + list to determine if a class can be loaded. + + :param o: primitive to deserialize into an arbitrary object. + :param full: if False it will return a stringified representation + of an object and will not load any classes + :param type_hint: if set it will be used to help determine what + object to deserialize in. It does not override if another + specification is found + :return: object + """ + if o is None: + return o + + if isinstance(o, _primitives): + return o + + if isinstance(o, _iterables): + return [deserialize(d) for d in o] + + if not isinstance(o, dict): + raise TypeError() + + o = _convert(o) + + # plain dict and no type hint + if CLASSNAME not in o and not type_hint or VERSION not in o: + return {str(k): deserialize(v, full) for k, v in o.items()} + + # custom deserialization starts here + cls: Any + version = 0 + value: Any + classname: str + + if type_hint: + cls = type_hint + classname = qualname(cls) + version = 0 # type hinting always sets version to 0 + value = o + + if CLASSNAME in o and VERSION in o: + classname, version, value = decode(o) + if not _match(classname) and classname not in _extra_allowed: + raise ImportError( + f"{classname} was not found in allow list for deserialization imports." + f"To allow it, add it to allowed_deserialization_classes in the configuration" + ) + + if full: + cls = import_string(classname) + + # only return string representation + if not full: + return _stringify(classname, version, value) + + # registered deserializer + if classname in _deserializers: + return _deserializers[classname].deserialize(classname, version, deserialize(value)) + + # class has deserialization function + if hasattr(cls, "deserialize"): + return getattr(cls, "deserialize")(deserialize(value), version) + + # attr or dataclass + if attr.has(cls) or dataclasses.is_dataclass(cls): + class_version = getattr(cls, "__version__", 0) + if int(version) > class_version: + raise TypeError( + "serialized version of %s is newer than module version (%s > %s)", + classname, + version, + class_version, + ) + + return cls(**deserialize(value)) + + # no deserializer available + raise TypeError(f"No deserializer found for {classname}") + + +def _convert(old: dict) -> dict: + """Converts an old style serialization to new style""" + if OLD_TYPE in old and OLD_DATA in old: + return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA][OLD_DATA]} + + return old + + +def _match(classname: str) -> bool: + for p in _patterns: + if p.match(classname): + return True + + return False + + +def _stringify(classname: str, version: int, value: T | None) -> str: + s = f"{classname}@version={version}(" + if isinstance(value, _primitives): + s += f"{value})" + elif isinstance(value, _iterables): + s += ",".join(str(serialize(value, False))) + elif isinstance(value, dict): + for k, v in value.items(): + s += f"{k}={str(serialize(v, False))}," + s = s[:-1] + ")" + + return s + + +def _register(): + """Register builtin serializers and deserializers for types that don't have any themselves""" + _serializers.clear() + _deserializers.clear() + + for _, name, _ in iter_namespace(airflow.serialization.serializers): + name = import_module(name) + for s in getattr(name, "serializers", list()): + if not isinstance(s, str): + s = qualname(s) + if s in _serializers and _serializers[s] != name: + raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}") + log.debug("registering %s for serialization") + _serializers[s] = name + for d in getattr(name, "deserializers", list()): + if not isinstance(d, str): + d = qualname(d) + if d in _deserializers and _deserializers[d] != name: + raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}") + log.debug("registering %s for deserialization", d) + _deserializers[d] = name + _extra_allowed.add(d) + + +def _compile_patterns(): + patterns = conf.get("core", "allowed_deserialization_classes").split() + + _patterns.clear() # ensure to reinit + for p in patterns: + _patterns.append(re.compile(p)) + + +_register() +_compile_patterns() diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 4563e570dddd6..d6340c4b512a9 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -55,7 +55,7 @@ from airflow.timetables.base import Timetable from airflow.utils.code_utils import get_python_source from airflow.utils.docs import get_docs_url -from airflow.utils.module_loading import as_importable_string, import_string +from airflow.utils.module_loading import import_string, qualname from airflow.utils.operator_resources import Resources from airflow.utils.task_group import MappedTaskGroup, TaskGroup @@ -182,7 +182,7 @@ def _encode_timetable(var: Timetable) -> dict[str, Any]: can be completely controlled by a custom subclass. """ timetable_class = type(var) - importable_string = as_importable_string(timetable_class) + importable_string = qualname(timetable_class) if _get_registered_timetable(importable_string) is None: raise _TimetableNotRegistered(importable_string) return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()} @@ -1005,19 +1005,19 @@ def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]: raise AirflowException("Can not load plugins") instances = set() - for qualname in set(deps): + for qn in set(deps): if ( - not qualname.startswith("airflow.ti_deps.deps.") - and qualname not in plugins_manager.registered_ti_dep_classes + not qn.startswith("airflow.ti_deps.deps.") + and qn not in plugins_manager.registered_ti_dep_classes ): raise SerializationError( - f"Custom dep class {qualname} not deserialized, please register it through plugins." + f"Custom dep class {qn} not deserialized, please register it through plugins." ) try: - instances.add(import_string(qualname)()) + instances.add(import_string(qn)()) except ImportError: - log.warning("Error importing dep %r", qualname, exc_info=True) + log.warning("Error importing dep %r", qn, exc_info=True) return instances @classmethod @@ -1108,6 +1108,15 @@ def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOper return serialize_operator_extra_links + @classmethod + def serialize(cls, var: Any, *, strict: bool = False) -> Any: + # the wonders of multiple inheritance BaseOperator defines an instance method + return BaseSerialization.serialize(var=var, strict=strict) + + @classmethod + def deserialize(cls, encoded_var: Any) -> Any: + return BaseSerialization.deserialize(encoded_var=encoded_var) + class SerializedDAG(DAG, BaseSerialization): """ @@ -1159,6 +1168,7 @@ def serialize_dag(cls, dag: DAG) -> dict: del serialized_dag["timetable"] serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()] + dag_deps = { dep for task in dag.task_dict.values() diff --git a/airflow/serialization/serializers/__init__.py b/airflow/serialization/serializers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/serialization/serializers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/serialization/serializers/bignum.py b/airflow/serialization/serializers/bignum.py new file mode 100644 index 0000000000000..972963467d8cb --- /dev/null +++ b/airflow/serialization/serializers/bignum.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from decimal import Decimal +from typing import TYPE_CHECKING + +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +serializers = [Decimal] +deserializers = serializers + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, Decimal): + name = qualname(o) + _, _, exponent = o.as_tuple() + if exponent >= 0: # No digits after the decimal point. + return int(o), name, __version__, True + # Technically lossy due to floating point errors, but the best we + # can do without implementing a custom encode function. + return float(o), name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: object) -> Decimal: + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if classname != qualname(Decimal): + raise TypeError(f"{classname} != {qualname(Decimal)}") + + return Decimal(str(data)) diff --git a/airflow/serialization/serializers/datetime.py b/airflow/serialization/serializers/datetime.py new file mode 100644 index 0000000000000..b400258838a0d --- /dev/null +++ b/airflow/serialization/serializers/datetime.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +from pendulum import DateTime +from pendulum.tz import timezone + +from airflow.utils.module_loading import qualname +from airflow.utils.timezone import convert_to_utc, is_naive + +if TYPE_CHECKING: + from airflow.serialization.serde import U + +__version__ = 1 + +serializers = [date, datetime, timedelta, DateTime] +deserializers = serializers + +TIMESTAMP = "timestamp" +TIMEZONE = "tz" + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, DateTime) or isinstance(o, datetime): + qn = qualname(o) + if is_naive(o): + o = convert_to_utc(o) + + tz = o.tzname() + + return {TIMESTAMP: o.timestamp(), TIMEZONE: tz}, qn, __version__, True + + if isinstance(o, date): + return o.isoformat(), qualname(o), __version__, True + + if isinstance(o, timedelta): + return o.total_seconds(), qualname(o), __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: dict | str) -> datetime | timedelta | date: + if classname == qualname(datetime) and isinstance(data, dict): + return datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) + + if classname == qualname(DateTime) and isinstance(data, dict): + return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) + + if classname == qualname(timedelta) and isinstance(data, (str, float)): + return timedelta(seconds=float(data)) + + if classname == qualname(date) and isinstance(data, str): + return date.fromisoformat(data) + + raise TypeError(f"unknown date/time format {classname}") diff --git a/airflow/serialization/serializers/kubernetes.py b/airflow/serialization/serializers/kubernetes.py new file mode 100644 index 0000000000000..a8f0c0f333de5 --- /dev/null +++ b/airflow/serialization/serializers/kubernetes.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from airflow.utils.module_loading import qualname + +serializers = [] + +try: + from kubernetes.client import models as k8s + + serializers = [k8s.v1_pod.V1Pod, k8s.V1ResourceRequirements] +except ImportError: + k8s = None + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +__version__ = 1 + +deserializers: list[type[object]] = [] +log = logging.getLogger(__name__) + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if not k8s: + return "", "", 0, False + + if isinstance(o, (k8s.V1Pod, k8s.V1ResourceRequirements)): + from airflow.kubernetes.pod_generator import PodGenerator + + def safe_get_name(pod): + """ + We're running this in an except block, so we don't want it to + fail under any circumstances, e.g. by accessing an attribute that isn't there + """ + try: + return pod.metadata.name + except Exception: + return None + + try: + return PodGenerator.serialize_pod(o), qualname(o), __version__, True + except Exception: + log.warning("Serialization failed for pod %s", safe_get_name(o)) + log.debug("traceback for serialization error", exc_info=True) + return "", "", 0, False + + return "", "", 0, False diff --git a/airflow/serialization/serializers/numpy.py b/airflow/serialization/serializers/numpy.py new file mode 100644 index 0000000000000..4dea70aa13742 --- /dev/null +++ b/airflow/serialization/serializers/numpy.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.utils.module_loading import qualname + +serializers = [] + +try: + import numpy as np + + serializers = [ + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.bool_, + np.float_, + np.float16, + np.float64, + np.complex_, + np.complex64, + np.complex128, + ] +except ImportError: + np = None # type: ignore + + +if TYPE_CHECKING: + from airflow.serialization.serde import U + +deserializers: list = serializers +_deserializers: dict[str, type[object]] = {qualname(x): x for x in deserializers} + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if np is None: + return "", "", 0, False + + name = qualname(o) + if isinstance( + o, + ( + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ): + return int(o), name, __version__, True + + if isinstance(o, np.bool_): + return bool(np), name, __version__, True + + if isinstance( + o, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128) + ): + return float(o), name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: str) -> Any: + if version > __version__: + raise TypeError("serialized version is newer than class version") + + f = _deserializers.get(classname, None) + if callable(f): + return f(data) # type: ignore [call-arg] + + raise TypeError(f"unsupported {classname} found for numpy deserialization") diff --git a/airflow/serialization/serializers/timezone.py b/airflow/serialization/serializers/timezone.py new file mode 100644 index 0000000000000..0a0bcc222c611 --- /dev/null +++ b/airflow/serialization/serializers/timezone.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pendulum +from pendulum.tz.timezone import FixedTimezone, Timezone + +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +serializers = [FixedTimezone, Timezone] +deserializers = serializers + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + """Encode a Pendulum Timezone for serialization. + + Airflow only supports timezone objects that implements Pendulum's Timezone + interface. We try to keep as much information as possible to make conversion + round-tripping possible (see ``decode_timezone``). We need to special-case + UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as + 0 without the special case), but passing 0 into ``pendulum.timezone`` does + not give us UTC (but ``+00:00``). + """ + name = qualname(o) + if isinstance(o, FixedTimezone): + if o.offset == 0: + return "UTC", name, __version__, True + return o.offset, name, __version__, True + + if isinstance(o, Timezone): + return o.name, name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: object) -> Timezone: + if not isinstance(data, (str, int)): + raise TypeError(f"{data} is not of type int or str but of {type(data)}") + + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if isinstance(data, int): + return pendulum.tz.fixed_timezone(data) + + return pendulum.tz.timezone(data) diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 16f36cb475f58..0a497aab898e2 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -17,118 +17,19 @@ # under the License. from __future__ import annotations -import dataclasses import json -import logging -import re from datetime import date, datetime from decimal import Decimal from typing import Any -import attr from flask.json.provider import JSONProvider -from airflow.configuration import conf -from airflow.serialization.enums import Encoding -from airflow.utils.module_loading import import_string +from airflow.serialization.serde import CLASSNAME, DATA, SCHEMA_ID, deserialize, serialize from airflow.utils.timezone import convert_to_utc, is_naive -try: - import numpy as np -except ImportError: - np = None # type: ignore - -try: - from kubernetes.client import models as k8s -except ImportError: - k8s = None - -# Dates and JSON encoding/decoding - -log = logging.getLogger(__name__) - -CLASSNAME = "__classname__" -VERSION = "__version__" -DATA = "__data__" - -OLD_TYPE = "__type" -OLD_SOURCE = "__source" -OLD_DATA = "__var" - -DEFAULT_VERSION = 0 - - -class AirflowJsonEncoder(json.JSONEncoder): - """Custom Airflow json encoder implementation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.default = self._default - - @staticmethod - def _default(obj): - """Convert dates and numpy objects in a json serializable format.""" - if isinstance(obj, datetime): - if is_naive(obj): - obj = convert_to_utc(obj) - return obj.isoformat() - elif isinstance(obj, date): - return obj.strftime("%Y-%m-%d") - elif isinstance(obj, Decimal): - _, _, exponent = obj.as_tuple() - if exponent >= 0: # No digits after the decimal point. - return int(obj) - # Technically lossy due to floating point errors, but the best we - # can do without implementing a custom encode function. - return float(obj) - elif np is not None and isinstance( - obj, - ( - np.int_, - np.intc, - np.intp, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ), - ): - return int(obj) - elif np is not None and isinstance(obj, np.bool_): - return bool(obj) - elif np is not None and isinstance( - obj, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128) - ): - return float(obj) - elif k8s is not None and isinstance(obj, (k8s.V1Pod, k8s.V1ResourceRequirements)): - from airflow.kubernetes.pod_generator import PodGenerator - - def safe_get_name(pod): - """ - We're running this in an except block, so we don't want it to - fail under any circumstances, e.g. by accessing an attribute that isn't there - """ - try: - return pod.metadata.name - except Exception: - return None - - try: - return PodGenerator.serialize_pod(obj) - except Exception: - log.warning("JSON encoding failed for pod %s", safe_get_name(obj)) - log.debug("traceback for pod JSON encode error", exc_info=True) - return {} - - raise TypeError(f"Object of type '{obj.__class__.__qualname__}' is not JSON serializable") - class AirflowJsonProvider(JSONProvider): - """JSON Provider for Flask app to use AirflowJsonEncoder.""" + """JSON Provider for Flask app to use WebEncoder.""" ensure_ascii: bool = True sort_keys: bool = True @@ -136,41 +37,55 @@ class AirflowJsonProvider(JSONProvider): def dumps(self, obj, **kwargs): kwargs.setdefault("ensure_ascii", self.ensure_ascii) kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder) + return json.dumps(obj, **kwargs, cls=WebEncoder) def loads(self, s: str | bytes, **kwargs): return json.loads(s, **kwargs) -# for now separate as AirflowJsonEncoder is non-standard -class XComEncoder(json.JSONEncoder): - """This encoder serializes any object that has attr, dataclass or a custom serializer.""" +class WebEncoder(json.JSONEncoder): + """This encodes values into a web understandable format. There is no deserializer""" + + def default(self, o: Any) -> Any: + if isinstance(o, datetime): + if is_naive(o): + o = convert_to_utc(o) + return o.isoformat() + + if isinstance(o, date): + return o.strftime("%Y-%m-%d") - def default(self, o: object) -> dict: - from airflow.serialization.serialized_objects import BaseSerialization + if isinstance(o, Decimal): + data = serialize(o) + if isinstance(data, dict) and DATA in data: + return data[DATA] - dct = { - CLASSNAME: o.__module__ + "." + o.__class__.__qualname__, - VERSION: getattr(o.__class__, "version", DEFAULT_VERSION), - } + try: + data = serialize(o) + if isinstance(data, dict) and CLASSNAME in data: + # this is here for backwards compatibility + if ( + data[CLASSNAME].startswith("numpy") + or data[CLASSNAME] == "kubernetes.client.models.v1_pod.V1Pod" + ): + return data[DATA] + return data + except TypeError: + raise - if hasattr(o, "serialize"): - dct[DATA] = getattr(o, "serialize")() - return dct - elif dataclasses.is_dataclass(o.__class__): - data = dataclasses.asdict(o) - dct[DATA] = BaseSerialization.serialize(data) - return dct - elif attr.has(o.__class__): - # Only include attributes which we can pass back to the classes constructor - data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init) # type: ignore[arg-type] - dct[DATA] = BaseSerialization.serialize(data) - return dct - else: + +class XComEncoder(json.JSONEncoder): + """This encoder serializes any object that has attr, dataclass or a custom serializer.""" + + def default(self, o: object) -> Any: + try: + return serialize(o) + except TypeError: return super().default(o) def encode(self, o: Any) -> str: - if isinstance(o, dict) and CLASSNAME in o: + # checked here and in serialize + if isinstance(o, dict) and (CLASSNAME in o or SCHEMA_ID in o): raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize") return super().encode(o) @@ -183,86 +98,20 @@ class XComDecoder(json.JSONDecoder): as is. """ - _pattern: list[re.Pattern] = [] - def __init__(self, *args, **kwargs) -> None: if not kwargs.get("object_hook"): kwargs["object_hook"] = self.object_hook - patterns = conf.get("core", "allowed_deserialization_classes").split() - - self._pattern.clear() # ensure to reinit - for p in patterns: - self._pattern.append(re.compile(p)) - super().__init__(*args, **kwargs) def object_hook(self, dct: dict) -> object: - dct = XComDecoder._convert(dct) - - if CLASSNAME in dct and VERSION in dct: - from airflow.serialization.serialized_objects import BaseSerialization - - classname = dct[CLASSNAME] - cls = None - - for p in self._pattern: - if p.match(classname): - cls = import_string(classname) - break - - if not cls: - raise ImportError(f"{classname} was not found in allow list for import") - - if hasattr(cls, "deserialize"): - return getattr(cls, "deserialize")(dct[DATA], dct[VERSION]) - - version = getattr(cls, "version", 0) - if int(dct[VERSION]) > version: - raise TypeError( - "serialized version of %s is newer than module version (%s > %s)", - dct[CLASSNAME], - dct[VERSION], - version, - ) - - if not attr.has(cls) and not dataclasses.is_dataclass(cls): - raise TypeError( - f"cannot deserialize: no deserialization method " - f"for {dct[CLASSNAME]} and not attr/dataclass decorated" - ) - - return cls(**BaseSerialization.deserialize(dct[DATA])) - - return dct + return deserialize(dct) @staticmethod def orm_object_hook(dct: dict) -> object: """Creates a readable representation of a serialized object""" - dct = XComDecoder._convert(dct) - if CLASSNAME in dct and VERSION in dct: - from airflow.serialization.serialized_objects import BaseSerialization - - if Encoding.VAR in dct[DATA] and Encoding.TYPE in dct[DATA]: - data = BaseSerialization.deserialize(dct[DATA]) - if not isinstance(data, dict): - raise TypeError(f"deserialized value should be a dict, but is {type(data)}") - else: - # custom serializer - data = dct[DATA] - - s = f"{dct[CLASSNAME]}@version={dct[VERSION]}(" - for k, v in data.items(): - s += f"{k}={v}," - s = s[:-1] + ")" - return s + return deserialize(dct, False) - return dct - - @staticmethod - def _convert(old: dict) -> dict: - """Converts an old style serialization to new style""" - if OLD_TYPE in old and OLD_SOURCE in old: - return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]} - return old +# backwards compatibility +AirflowJsonEncoder = WebEncoder diff --git a/airflow/utils/module_loading.py b/airflow/utils/module_loading.py index dc361f9f95f89..3053a2ed0445f 100644 --- a/airflow/utils/module_loading.py +++ b/airflow/utils/module_loading.py @@ -17,10 +17,13 @@ # under the License. from __future__ import annotations +import pkgutil from importlib import import_module +from types import ModuleType +from typing import Callable -def import_string(dotted_path): +def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. @@ -38,6 +41,24 @@ def import_string(dotted_path): raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') -def as_importable_string(thing) -> str: - """Convert an attribute/class to a string importable by ``import_string``.""" - return f"{thing.__module__}.{thing.__name__}" +def qualname(o: object | Callable) -> str: + """Convert an attribute/class/function to a string importable by ``import_string``.""" + if callable(o): + return f"{o.__module__}.{o.__name__}" + + cls = o + + if not isinstance(cls, type): # instance or class + cls = type(cls) + + name = cls.__qualname__ + module = cls.__module__ + + if module and module != "__builtin__": + return f"{module}.{name}" + + return name + + +def iter_namespace(ns: ModuleType): + return pkgutil.iter_modules(ns.__path__, ns.__name__ + ".") diff --git a/airflow/www/views.py b/airflow/www/views.py index 54800266bebd5..16b4818d7dbdb 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3739,7 +3739,7 @@ def datasets_summary(self): data = {"datasets": datasets, "total_entries": count_query.scalar()} return ( - htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.AirflowJsonEncoder), + htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.WebEncoder), {"Content-Type": "application/json; charset=utf-8"}, ) diff --git a/docs/apache-airflow/concepts/index.rst b/docs/apache-airflow/concepts/index.rst index 6663fc004b333..61704dacc1afa 100644 --- a/docs/apache-airflow/concepts/index.rst +++ b/docs/apache-airflow/concepts/index.rst @@ -48,6 +48,7 @@ Here you can find detailed documentation about each one of Airflow's core concep timetable priority-weight cluster-policies + serializers **Communication** diff --git a/docs/apache-airflow/concepts/serializers.rst b/docs/apache-airflow/concepts/serializers.rst new file mode 100644 index 0000000000000..80aa8219ee39f --- /dev/null +++ b/docs/apache-airflow/concepts/serializers.rst @@ -0,0 +1,128 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Serialization +============= + +To support data exchange, like arguments, between tasks, Airflow needs to serialize the data to be exchanged and +deserialize it again when required in a downstream task. Serialization also happens so that the webserver and +the scheduler (as opposed to the dag processor) do no need to read the DAG file. This is done for security purposes +and efficiency. + +Serialization is a surprisingly hard job. Python out of the box only has support for serialization of primitives, +like ``str`` and ``int`` and it loops over iterables. When things become more complex, custom serialization is required. + +Airflow out of the box supports three ways of custom serialization. Primitives are are returned as is, without +additional encoding, e.g. a ``str`` remains a ``str``. When it is not a primitive (or iterable thereof) Airflow +looks for a registered serializer and deserializer in the namespace of ``airflow.serialization.serializers``. +If not found it will look in the class for a ``serialize()`` method or in case of deserialization a +``deserialize(data, version: int)`` method. Finally, if the class is either decorated with ``@dataclass`` +or ``@attr.define`` it will use the public methods for those decorators. + +If you are looking to extend Airflow with a new serializer, it is good to know when to choose what way of serialization. +Objects that are under the control of Airflow, i.e. residing under the namespace of ``airflow.*`` like +``airflow.model.dag.DAG`` or under control of the developer e.g. ``my.company.Foo`` should first be examined to see +whether they can be decorated with ``@attr.define`` or ``@dataclass``. If that is not possible then the ``serialize`` +and ``deserialize`` methods should be implemented. The ``serialize`` method should return a primitive or a dict. +It does not need to serialize the values in the dict, that will be taken care of, but the keys should be of a primitive +form. + +Objects that are not under control of Airflow, e.g. ``numpy.int16`` will need a registered serializer and deserializer. +Versioning is required. Primitives can be returned as can dicts. Again ``dict`` values do not need to be serialized, +but its keys need to be of primitive form. In case you are implementing a registered serializer, take special care +not to have circular imports. Typically, this can be avoided by using ``str`` for populating the list of serializers. +Like so: ``serializers = ["my.company.Foo"]`` instead of ``serializers = [Foo]``. + +:: + + Note: Serialization and deserialization is dependent on speed. Use built-in functions like ``dict`` as much as you can and stay away from using classes and other complex structures. + +Airflow Object +-------------- + +.. code-block:: python + + from typing import Any, ClassVar + + + class Foo: + __version__: ClassVar[int] = 1 + + def __init__(self, a, v) -> None: + self.a = a + self.b = {"x": v} + + def serialize(self) -> dict[str, Any]: + return { + "a": self.a, + "b": self.b, + } + + @staticmethod + def deserialize(data: dict[str, Any], version: int): + f = Foo(a=data["a"]) + f.b = data["b"] + return f + + +Registered +^^^^^^^^^^ + +.. code-block:: python + + from __future__ import annotations + + from decimal import Decimal + from typing import TYPE_CHECKING + + from airflow.utils.module_loading import qualname + + if TYPE_CHECKING: + from airflow.serialization.serde import U + + + serializers = [ + Decimal + ] # this can be a type or a fully qualified str. Str can be used to prevent circular imports + deserializers = serializers # in some cases you might not have a deserializer (e.g. k8s pod) + + __version__ = 1 # required + + # the serializer expects output, classname, version, is_serialized? + def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, Decimal): + name = qualname(o) + _, _, exponent = o.as_tuple() + if exponent >= 0: # No digits after the decimal point. + return int(o), name, __version__, True + # Technically lossy due to floating point errors, but the best we + # can do without implementing a custom encode function. + return float(o), name, __version__, True + + return "", "", 0, False + + + # the deserializer sanitizes the data for you, so you do not need to deserialize values yourself + def deserialize(classname: str, version: int, data: object) -> Decimal: + # always check version compatibility + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if classname != qualname(Decimal): + raise TypeError(f"{classname} != {qualname(Decimal)}") + + return Decimal(str(data)) diff --git a/docs/apache-airflow/concepts/taskflow.rst b/docs/apache-airflow/concepts/taskflow.rst index 7efa38b8430c7..8cc87ee09bb55 100644 --- a/docs/apache-airflow/concepts/taskflow.rst +++ b/docs/apache-airflow/concepts/taskflow.rst @@ -156,7 +156,7 @@ yourself. To do so add the ``serialize()`` method to your class and the staticme class MyCustom: - version: ClassVar[int] = 1 + __version__: ClassVar[int] = 1 def __init__(self, x): self.x = x @@ -174,13 +174,13 @@ Object Versioning ^^^^^^^^^^^^^^^^^ It is good practice to version the objects that will be used in serialization. To do this add -``version: ClassVar[int] = `` to your class. Airflow assumes that your classes are backwards compatible, +``__version__: ClassVar[int] = `` to your class. Airflow assumes that your classes are backwards compatible, so that a version 2 is able to deserialize a version 1. In case you need custom logic for deserialization ensure that ``deserialize(data: dict, version: int)`` is specified. :: - Note: Typing of ``version`` is required and needs to be ``ClassVar[int]`` + Note: Typing of ``__version__`` is required and needs to be ``ClassVar[int]`` History ------- diff --git a/docs/apache-airflow/integration.rst b/docs/apache-airflow/integration.rst index a628f95dec6a6..e77f72ea5af3c 100644 --- a/docs/apache-airflow/integration.rst +++ b/docs/apache-airflow/integration.rst @@ -32,6 +32,7 @@ Airflow has a mechanism that allows you to expand its functionality and integrat * :doc:`Secrets backends ` * :doc:`Tracking systems ` * :doc:`Web UI Authentication backends ` +* :doc:`Serialization ` It also has integration with :doc:`Sentry ` service for error tracking. Other applications can also integrate using the :doc:`REST API `. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 5b186ae821fb5..e05912d2748a1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -411,6 +411,7 @@ Deserialize deserialize Deserialized deserialized +deserializer deserializing dest dev @@ -1278,6 +1279,8 @@ serde serialise serializable SerializedDAG +serializer +serializers serverless ServiceAccount setattr diff --git a/tests/serialization/serializers/__init__.py b/tests/serialization/serializers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/serialization/serializers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py new file mode 100644 index 0000000000000..e4f146b9d9502 --- /dev/null +++ b/tests/serialization/serializers/test_serializers.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import decimal + +import numpy +import pendulum.tz +import pytest +from pendulum import DateTime + +from airflow.models.param import Param, ParamsDict +from airflow.serialization.serde import DATA, deserialize, serialize + + +class TestSerializers: + def test_datetime(self): + i = datetime.datetime(2022, 7, 10, 22, 10, 43, microsecond=0, tzinfo=pendulum.tz.UTC) + + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + + i = DateTime(2022, 7, 10, tzinfo=pendulum.tz.UTC) + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + + i = datetime.date(2022, 7, 10) + s = serialize(i) + d = deserialize(s) + assert i == d + + i = datetime.timedelta(days=320) + s = serialize(i) + d = deserialize(s) + assert i == d + + @pytest.mark.parametrize( + "expr, expected", + [("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")], + ) + def test_encode_decimal(self, expr, expected): + assert deserialize(serialize(decimal.Decimal(expr))) == decimal.Decimal(expected) + + def test_encode_k8s_v1pod(self): + from kubernetes.client import models as k8s + + pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo", + namespace="bar", + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="foo", + image="bar", + ) + ] + ), + ) + assert serialize(pod)[DATA] == { + "metadata": {"name": "foo", "namespace": "bar"}, + "spec": {"containers": [{"image": "bar", "name": "foo"}]}, + } + + def test_numpy(self): + i = numpy.int16(10) + e = serialize(i) + d = deserialize(e) + assert i == d + + def test_params(self): + i = ParamsDict({"x": Param(default="value", description="there is a value", key="test")}) + e = serialize(i) + d = deserialize(e) + assert i["x"] == d["x"] diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py new file mode 100644 index 0000000000000..8a00ed5cf0411 --- /dev/null +++ b/tests/serialization/test_serde.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import enum +from dataclasses import dataclass +from typing import ClassVar + +import attr +import pytest + +from airflow.datasets import Dataset +from airflow.serialization.serde import ( + CLASSNAME, + DATA, + SCHEMA_ID, + VERSION, + _compile_patterns, + deserialize, + serialize, +) +from airflow.utils.module_loading import qualname +from tests.test_utils.config import conf_vars + + +class Z: + __version__: ClassVar[int] = 1 + + def __init__(self, x): + self.x = x + + def serialize(self) -> dict: + return dict({"x": self.x}) + + @staticmethod + def deserialize(data: dict, version: int): + if version != 1: + raise TypeError("version != 1") + return Z(data["x"]) + + +@attr.define +class Y: + x: int + __version__: ClassVar[int] = 1 + + def __init__(self, x): + self.x = x + + +class X: + pass + + +@dataclass +class W: + __version__: ClassVar[int] = 2 + x: int + + +class TestSerDe: + @pytest.fixture(autouse=True) + def ensure_clean_allow_list(self): + _compile_patterns() + yield + + def test_ser_primitives(self): + i = 10 + e = serialize(i) + assert i == e + + i = 10.1 + e = serialize(i) + assert i == e + + i = "test" + e = serialize(i) + assert i == e + + i = True + e = serialize(i) + assert i == e + + Color = enum.IntEnum("Color", ["RED", "GREEN"]) + i = Color.RED + e = serialize(i) + assert i == e + + def test_ser_iterables(self): + i = [1, 2] + e = serialize(i) + assert i == e + + i = ("a", "b", "a", "c") + e = serialize(i) + assert i == e + + i = {2, 3} + e = serialize(i) + assert i == e + + def test_ser_plain_dict(self): + i = {"a": 1, "b": 2} + e = serialize(i) + assert i == e + + with pytest.raises(AttributeError, match="^reserved"): + i = {CLASSNAME: "cannot"} + serialize(i) + + with pytest.raises(AttributeError, match="^reserved"): + i = {SCHEMA_ID: "cannot"} + serialize(i) + + def test_no_serializer(self): + with pytest.raises(TypeError, match="^cannot serialize"): + i = Exception + serialize(i) + + def test_ser_registered(self): + i = datetime.datetime(2000, 10, 1) + e = serialize(i) + assert e[DATA] + + def test_serder_custom(self): + i = Z(1) + e = serialize(i) + assert Z.__version__ == e[VERSION] + assert qualname(Z) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + def test_serder_attr(self): + i = Y(10) + e = serialize(i) + assert Y.__version__ == e[VERSION] + assert qualname(Y) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + def test_serder_dataclass(self): + i = W(12) + e = serialize(i) + assert W.__version__ == e[VERSION] + assert qualname(W) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "airflow[.].*", + } + ) + def test_allow_list_for_imports(self): + _compile_patterns() + i = Z(10) + e = serialize(i) + with pytest.raises(ImportError) as ex: + deserialize(e) + + assert f"{qualname(Z)} was not found in allow list" in str(ex.value) + + def test_incompatible_version(self): + data = dict( + { + "__classname__": Y.__module__ + "." + Y.__qualname__, + "__version__": 2, + } + ) + with pytest.raises(TypeError, match="newer than"): + deserialize(data) + + def test_raise_undeserializable(self): + data = dict( + { + "__classname__": X.__module__ + "." + X.__qualname__, + "__version__": 0, + } + ) + with pytest.raises(TypeError, match="No deserializer"): + deserialize(data) + + def test_backwards_compat(self): + uri = "s3://does_not_exist" + data = { + "__type": "airflow.datasets.Dataset", + "__source": None, + "__var": { + "__var": { + "uri": uri, + "extra": None, + }, + "__type": "dict", + }, + } + dataset = deserialize(data) + assert dataset.uri == uri + + def test_encode_dataset(self): + dataset = Dataset("mytest://dataset") + obj = deserialize(serialize(dataset)) + assert dataset.uri == obj.uri diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index 2376b5ce2fd36..bcf11dff10990 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -17,112 +17,53 @@ # under the License. from __future__ import annotations -import decimal import json from dataclasses import dataclass from datetime import date, datetime from typing import ClassVar -import attr import numpy as np import pendulum import pytest from airflow.datasets import Dataset from airflow.utils import json as utils_json -from tests.test_utils.config import conf_vars - - -class Z: - version = 1 - - def __init__(self, x): - self.x = x - - def serialize(self) -> dict: - return dict({"x": self.x}) - - @staticmethod - def deserialize(data: dict, version: int): - if version != 1: - raise TypeError("version != 1") - return Z(data["x"]) - - -@attr.define -class Y: - x: int - version: ClassVar[int] = 1 - - def __init__(self, x): - self.x = x - - -class X: - pass @dataclass class U: - version: ClassVar[int] = 2 + __version__: ClassVar[int] = 2 x: int -class TestAirflowJsonEncoder: +class TestWebEncoder: def test_encode_datetime(self): obj = datetime.strptime("2017-05-21 00:00:00", "%Y-%m-%d %H:%M:%S") - assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+00:00"' + assert json.dumps(obj, cls=utils_json.WebEncoder) == '"2017-05-21T00:00:00+00:00"' def test_encode_pendulum(self): obj = pendulum.datetime(2017, 5, 21, tz="Asia/Kolkata") - assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+05:30"' + assert json.dumps(obj, cls=utils_json.WebEncoder) == '"2017-05-21T00:00:00+05:30"' def test_encode_date(self): - assert json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"' - - @pytest.mark.parametrize( - "expr, expected", - [("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")], - ) - def test_encode_decimal(self, expr, expected): - assert json.dumps(decimal.Decimal(expr), cls=utils_json.AirflowJsonEncoder) == expected + assert json.dumps(date(2017, 5, 21), cls=utils_json.WebEncoder) == '"2017-05-21"' def test_encode_numpy_int(self): - assert json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder) == "5" + assert json.dumps(np.int32(5), cls=utils_json.WebEncoder) == "5" def test_encode_numpy_bool(self): - assert json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder) == "true" + assert json.dumps(np.bool_(True), cls=utils_json.WebEncoder) == "true" def test_encode_numpy_float(self): - assert json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder) == "3.76953125" - - def test_encode_k8s_v1pod(self): - from kubernetes.client import models as k8s - - pod = k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - name="foo", - namespace="bar", - ), - spec=k8s.V1PodSpec( - containers=[ - k8s.V1Container( - name="foo", - image="bar", - ) - ] - ), - ) - assert json.loads(json.dumps(pod, cls=utils_json.AirflowJsonEncoder)) == { - "metadata": {"name": "foo", "namespace": "bar"}, - "spec": {"containers": [{"image": "bar", "name": "foo"}]}, - } + assert json.dumps(np.float16(3.76953125), cls=utils_json.WebEncoder) == "3.76953125" + +class TestXComEncoder: def test_encode_raises(self): with pytest.raises(TypeError, match="^.*is not JSON serializable$"): json.dumps( Exception, - cls=utils_json.AirflowJsonEncoder, + cls=utils_json.XComEncoder, ) def test_encode_xcom_dataset(self): @@ -131,103 +72,9 @@ def test_encode_xcom_dataset(self): obj = json.loads(s, cls=utils_json.XComDecoder) assert dataset.uri == obj.uri - def test_backwards_compat(self): - uri = "s3://does_not_exist" - data = { - "__type": "airflow.datasets.Dataset", - "__source": None, - "__var": { - "__var": { - "uri": uri, - "extra": None, - }, - "__type": "dict", - }, - } - decoder = utils_json.XComDecoder() - dataset = decoder.object_hook(data) - assert dataset.uri == uri - - def test_raise_on_reserved(self): - data = {"__classname__": "my.class"} - with pytest.raises(AttributeError): - json.dumps(data, cls=utils_json.XComEncoder) - - def test_custom_serialize(self): - x = 11 - z = Z(x) - s = json.dumps(z, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - - def test_raise_undeserializable(self): - data = dict( - { - "__classname__": X.__module__ + "." + X.__qualname__, - "__version__": 0, - } - ) - s = json.dumps(data) - with pytest.raises(TypeError) as info: - json.loads(s, cls=utils_json.XComDecoder) - - assert "cannot deserialize" in str(info.value) - - def test_attr_version(self): - x = 14 - y = Y(x) - s = json.dumps(y, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - - def test_incompatible_version(self): - data = dict( - { - "__classname__": Y.__module__ + "." + Y.__qualname__, - "__version__": 2, - } - ) - s = json.dumps(data) - with pytest.raises(TypeError) as info: - json.loads(s, cls=utils_json.XComDecoder) - - assert "newer than" in str(info.value) - - def test_dataclass(self): - x = 12 - u = U(x=x) - s = json.dumps(u, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - def test_orm_deserialize(self): x = 14 u = U(x=x) s = json.dumps(u, cls=utils_json.XComEncoder) o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) - assert o == f"{U.__module__}.{U.__qualname__}@version={U.version}(x={x})" - - def test_orm_custom_deserialize(self): - x = 14 - z = Z(x=x) - s = json.dumps(z, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) - assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})" - - @conf_vars( - { - ("core", "allowed_deserialization_classes"): "airflow[.].*", - } - ) - def test_allow_list_for_imports(self): - x = 14 - z = Z(x=x) - s = json.dumps(z, cls=utils_json.XComEncoder) - - with pytest.raises(ImportError) as e: - json.loads(s, cls=utils_json.XComDecoder) - - assert f"{Z.__module__}.{Z.__qualname__} was not found in allow list" in str(e.value) + assert o == f"{U.__module__}.{U.__qualname__}@version={U.__version__}(x={x})"