From 326f227501ebcd3575e83ce413739b7debff3b01 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Mon, 28 Nov 2022 10:00:38 +0100 Subject: [PATCH 1/6] Generalize serialization and deserialization Seralization and deserialization (derser) of objects that were not under the control of Airflow of the dag author was not possible. This implementation allows for extension of serialization of arbitrary objects by putting them in the namespace of airflow. Selection of serializer happens in order of registered serializer, custom serializer, attr/dataclass annottated object. --- airflow/models/param.py | 25 +- airflow/plugins_manager.py | 21 +- airflow/providers/amazon/aws/hooks/eks.py | 10 +- .../amazon/aws/operators/sagemaker.py | 4 +- airflow/serialization/serde.py | 306 ++++++++++++++++++ airflow/serialization/serialized_objects.py | 26 +- airflow/serialization/serializers/__init__.py | 17 + airflow/serialization/serializers/bignum.py | 55 ++++ airflow/serialization/serializers/datetime.py | 73 +++++ .../serialization/serializers/kubernetes.py | 68 ++++ airflow/serialization/serializers/numpy.py | 104 ++++++ airflow/serialization/serializers/timezone.py | 69 ++++ airflow/utils/json.py | 239 +++----------- airflow/utils/module_loading.py | 29 +- airflow/www/views.py | 2 +- docs/apache-airflow/concepts/taskflow.rst | 6 +- docs/apache-airflow/developer/serializers.rst | 128 ++++++++ docs/apache-airflow/integration.rst | 1 + tests/serialization/serializers/__init__.py | 17 + .../serializers/test_serializers.py | 93 ++++++ tests/serialization/test_serde.py | 223 +++++++++++++ tests/utils/test_json.py | 177 +--------- 22 files changed, 1293 insertions(+), 400 deletions(-) create mode 100644 airflow/serialization/serde.py create mode 100644 airflow/serialization/serializers/__init__.py create mode 100644 airflow/serialization/serializers/bignum.py create mode 100644 airflow/serialization/serializers/datetime.py create mode 100644 airflow/serialization/serializers/kubernetes.py create mode 100644 airflow/serialization/serializers/numpy.py create mode 100644 airflow/serialization/serializers/timezone.py create mode 100644 docs/apache-airflow/developer/serializers.rst create mode 100644 tests/serialization/serializers/__init__.py create mode 100644 tests/serialization/serializers/test_serializers.py create mode 100644 tests/serialization/test_serde.py diff --git a/airflow/models/param.py b/airflow/models/param.py index d944ef20311d3..2f93f7dd88341 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -21,7 +21,7 @@ import json import logging import warnings -from typing import TYPE_CHECKING, Any, ItemsView, Iterable, MutableMapping, ValuesView +from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning from airflow.utils.context import Context @@ -47,6 +47,8 @@ class Param: default & description will form the schema """ + __version__: ClassVar[int] = 1 + CLASS_IDENTIFIER = "__class" def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs): @@ -112,6 +114,16 @@ def dump(self) -> dict: def has_value(self) -> bool: return self.value is not NOTSET + def serialize(self) -> dict: + return {"value": self.value, "description": self.description, "schema": self.schema} + + @staticmethod + def deserialize(data: dict[str, Any], version: int) -> Param: + if version > Param.__version__: + raise TypeError("serialized version > class version") + + return Param(default=data["value"], description=data["description"], schema=data["schema"]) + class ParamsDict(MutableMapping[str, Any]): """ @@ -120,6 +132,7 @@ class ParamsDict(MutableMapping[str, Any]): dictionary implicitly and ideally not needed to be used directly. """ + __version__: ClassVar[int] = 1 __slots__ = ["__dict", "suppress_exception"] def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False): @@ -231,6 +244,16 @@ def validate(self) -> dict[str, Any]: return resolved_dict + def serialize(self) -> dict[str, Any]: + return self.dump() + + @staticmethod + def deserialize(data: dict, version: int) -> ParamsDict: + if version > ParamsDict.__version__: + raise TypeError("serialized version > class version") + + return ParamsDict(data) + class DagParam(ResolveMixin): """DAG run parameter reference. diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 1d887110208a0..d5fed0d3a4297 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -38,7 +38,7 @@ from airflow import settings from airflow.utils.entry_points import entry_points_with_dist from airflow.utils.file import find_path_from_directory -from airflow.utils.module_loading import as_importable_string +from airflow.utils.module_loading import qualname if TYPE_CHECKING: from airflow.hooks.base import BaseHook @@ -373,7 +373,7 @@ def initialize_ti_deps_plugins(): for plugin in plugins: registered_ti_dep_classes.update( - {as_importable_string(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps} + {qualname(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps} ) @@ -406,7 +406,7 @@ def initialize_extra_operators_links_plugins(): operator_extra_links.extend(list(plugin.operator_extra_links)) registered_operator_link_classes.update( - {as_importable_string(link.__class__): link.__class__ for link in plugin.operator_extra_links} + {qualname(link.__class__): link.__class__ for link in plugin.operator_extra_links} ) @@ -425,7 +425,7 @@ def initialize_timetables_plugins(): log.debug("Initialize extra timetables plugins") timetable_classes = { - as_importable_string(timetable_class): timetable_class + qualname(timetable_class): timetable_class for plugin in plugins for timetable_class in plugin.timetables } @@ -525,25 +525,20 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str info: dict[str, Any] = {"name": plugin.name} for attr in attrs_to_dump: if attr in ("global_operator_extra_links", "operator_extra_links"): - info[attr] = [ - f"<{as_importable_string(d.__class__)} object>" for d in getattr(plugin, attr) - ] + info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)] elif attr in ("macros", "timetables", "hooks", "executors"): - info[attr] = [as_importable_string(d) for d in getattr(plugin, attr)] + info[attr] = [qualname(d) for d in getattr(plugin, attr)] elif attr == "listeners": # listeners are always modules info[attr] = [d.__name__ for d in getattr(plugin, attr)] elif attr == "appbuilder_views": info[attr] = [ - {**d, "view": as_importable_string(d["view"].__class__) if "view" in d else None} + {**d, "view": qualname(d["view"].__class__) if "view" in d else None} for d in getattr(plugin, attr) ] elif attr == "flask_blueprints": info[attr] = [ - ( - f"<{as_importable_string(d.__class__)}: " - f"name={d.name!r} import_name={d.import_name!r}>" - ) + f"<{qualname(d.__class__)}: name={d.name!r} import_name={d.import_name!r}>" for d in getattr(plugin, attr) ] else: diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py index 03488dea96771..d74b2a50536d9 100644 --- a/airflow/providers/amazon/aws/hooks/eks.py +++ b/airflow/providers/amazon/aws/hooks/eks.py @@ -31,7 +31,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.utils import yaml -from airflow.utils.json import AirflowJsonEncoder +from airflow.utils.json import WebEncoder DEFAULT_PAGINATION_TOKEN = "" STS_TOKEN_EXPIRES_IN = 60 @@ -276,7 +276,7 @@ def describe_cluster(self, name: str, verbose: bool = False) -> dict: ) if verbose: cluster_data = response.get("cluster") - self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder)) + self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=WebEncoder)) return response def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict: @@ -302,7 +302,7 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool nodegroup_data = response.get("nodegroup") self.log.info( "Amazon EKS managed node group details: %s", - json.dumps(nodegroup_data, cls=AirflowJsonEncoder), + json.dumps(nodegroup_data, cls=WebEncoder), ) return response @@ -331,9 +331,7 @@ def describe_fargate_profile( ) if verbose: fargate_profile_data = response.get("fargateProfile") - self.log.info( - "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder) - ) + self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=WebEncoder)) return response def get_cluster_state(self, clusterName: str) -> ClusterStates: diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 4b969002b339e..3acba423eb373 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -26,7 +26,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook -from airflow.utils.json import AirflowJsonEncoder +from airflow.utils.json import WebEncoder if TYPE_CHECKING: from airflow.utils.context import Context @@ -36,7 +36,7 @@ def serialize(result: dict) -> str: - return json.loads(json.dumps(result, cls=AirflowJsonEncoder)) + return json.loads(json.dumps(result, cls=WebEncoder)) class SageMakerBaseOperator(BaseOperator): diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py new file mode 100644 index 0000000000000..401edaaa79917 --- /dev/null +++ b/airflow/serialization/serde.py @@ -0,0 +1,306 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import dataclasses +import enum +import logging +import re +import sys +from importlib import import_module +from types import ModuleType +from typing import Any, TypeVar, Union + +import attr + +import airflow.serialization.serializers +from airflow.configuration import conf +from airflow.utils.module_loading import import_string, iter_namespace, qualname + +log = logging.getLogger(__name__) + +MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1 + +CLASSNAME = "__classname__" +VERSION = "__version__" +DATA = "__data__" +SCHEMA_ID = "__id__" +CACHE = "__cache__" + +OLD_TYPE = "__type" +OLD_SOURCE = "__source" +OLD_DATA = "__var" + +DEFAULT_VERSION = 0 + +T = TypeVar("T", bool, float, int, dict, list, str, tuple, set) +U = Union[bool, float, int, dict, list, str, tuple, set] +S = Union[list, tuple, set] + +_serializers: dict[str, ModuleType] = {} +_deserializers: dict[str, ModuleType] = {} +_extra_allowed: set[str] = set() + +_primitives = (int, bool, float, str) +_iterables = (list, set, tuple) +_patterns: list[re.Pattern] = [] + +_reverse_cache: dict[int, tuple[ModuleType, str, int]] = {} + + +def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]: + """Encodes o so it can be understood by the deserializer""" + return {CLASSNAME: cls, VERSION: version, DATA: data} + + +def decode(d: dict[str, str | int | T]) -> tuple: + return d[CLASSNAME], d[VERSION], d.get(DATA, None) + + +def serialize(o: object, depth: int = 0) -> U | None: + """ + Recursively serializes objects into a primitive. Primitives (int, float, int, bool) + are returned as is. Tuples and dicts are iterated over, where it is assumed that keys + for dicts can be represented as str. Values that are not primitives are serialized if + a serializer is found for them. The order in which serializers are used + is 1) a serialize function provided by the object 2) a registered serializer in + the namespace of airflow.serialization.serializers and 3) an attr or dataclass annotations. + If a serializer cannot be found a TypeError is raised. + + :param o: object to serialize + :param depth: private + :return: a primitive + """ + if depth == MAX_RECURSION_DEPTH: + raise RecursionError("maximum recursion depth reached for serialization") + + # None remains None + if o is None: + return o + + # primitive types are returned as is + if isinstance(o, _primitives): + if isinstance(o, enum.Enum): + return o.value + + return o + + # tuples and plain dicts are iterated over recursively + if isinstance(o, _iterables): + s = [serialize(d, depth + 1) for d in o] + if isinstance(o, tuple): + return tuple(s) + if isinstance(o, set): + return set(s) + return s + + if isinstance(o, dict): + if CLASSNAME in o or SCHEMA_ID in o: + raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize") + + return {str(k): serialize(v, depth + 1) for k, v in o.items()} + + cls = type(o) + qn = qualname(o) + + # custom serializers + dct = { + CLASSNAME: qn, + VERSION: getattr(cls, "__version__", DEFAULT_VERSION), + } + + # if there is a builtin serializer available use that + if qn in _serializers: + data, classname, version, is_serialized = _serializers[qn].serialize(o) + if is_serialized: + return encode(classname, version, serialize(data, depth + 1)) + + # object / class brings their own + if hasattr(o, "serialize"): + data = getattr(o, "serialize")() + + # if we end up with a structure, ensure its values are serialized + if isinstance(data, dict): + data = serialize(data, depth + 1) + + dct[DATA] = data + return dct + + # dataclasses + if dataclasses.is_dataclass(cls): + data = dataclasses.asdict(o) + dct[DATA] = serialize(data, depth + 1) + return dct + + # attr annotated + if attr.has(cls): + # Only include attributes which we can pass back to the classes constructor + data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init) # type: ignore[arg-type] + dct[DATA] = serialize(data, depth + 1) + return dct + + raise TypeError(f"cannot serialize object of type {cls}") + + +def deserialize(o: T | None, full=True, type_hint: Any = None) -> object: + """ + Deserializes an object of primitive type T into an object. Uses an allow + list to determine if a class can be loaded. + + :param o: primitive to deserialize into an arbitrary object. + :param full: if False it will return a stringified representation + of an object and will not load any classes + :param type_hint: if set it will be used to help determine what + object to deserialize in. It does not override if another + specification is found + :return: object + """ + if o is None: + return o + + if isinstance(o, _primitives): + return o + + if isinstance(o, _iterables): + return [deserialize(d) for d in o] + + if not isinstance(o, dict): + raise TypeError() + + o = _convert(o) + + # plain dict and no type hint + if CLASSNAME not in o and not type_hint or VERSION not in o: + return {str(k): deserialize(v, full) for k, v in o.items()} + + # custom deserialization starts here + cls: Any + version = 0 + value: Any + classname: str + + if type_hint: + cls = type_hint + classname = qualname(cls) + version = 0 # type hinting always sets version to 0 + value = o + + if CLASSNAME in o and VERSION in o: + classname, version, value = decode(o) + if not _match(classname) and classname not in _extra_allowed: + raise ImportError( + f"{classname} was not found in allow list for deserialization imports." + f"To allow it, add it to allowed_deserialization_classes in the configuration" + ) + + if full: + cls = import_string(classname) + + # only return string representation + if not full: + return _stringify(classname, version, value) + + # registered deserializer + if classname in _deserializers: + return _deserializers[classname].deserialize(classname, version, deserialize(value)) + + # class has deserialization function + if hasattr(cls, "deserialize"): + return getattr(cls, "deserialize")(deserialize(value), version) + + # attr or dataclass + if attr.has(cls) or dataclasses.is_dataclass(cls): + class_version = getattr(cls, "__version__", 0) + if int(version) > class_version: + raise TypeError( + "serialized version of %s is newer than module version (%s > %s)", + classname, + version, + class_version, + ) + + return cls(**deserialize(value)) + + # no deserializer available + raise TypeError(f"No deserializer found for {classname}") + + +def _convert(old: dict) -> dict: + """Converts an old style serialization to new style""" + if OLD_TYPE in old and OLD_DATA in old: + return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA][OLD_DATA]} + + return old + + +def _match(classname: str) -> bool: + for p in _patterns: + if p.match(classname): + return True + + return False + + +def _stringify(classname: str, version: int, value: T | None) -> str: + s = f"{classname}@version={version}(" + if isinstance(value, _primitives): + s += f"{value})" + elif isinstance(value, _iterables): + s += ",".join(str(serialize(value, False))) + elif isinstance(value, dict): + for k, v in value.items(): + s += f"{k}={str(serialize(v, False))}," + s = s[:-1] + ")" + + return s + + +def _register(): + """Register builtin serializers and deserializers for types that don't have any themselves""" + _serializers.clear() + _deserializers.clear() + + for _, name, _ in iter_namespace(airflow.serialization.serializers): + name = import_module(name) + for s in getattr(name, "serializers", list()): + if not isinstance(s, str): + s = qualname(s) + if s in _serializers and _serializers[s] != name: + raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}") + log.debug("registering %s for serialization") + _serializers[s] = name + for d in getattr(name, "deserializers", list()): + if not isinstance(d, str): + d = qualname(d) + if d in _deserializers and _deserializers[d] != name: + raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}") + log.debug("registering %s for deserialization", d) + _deserializers[d] = name + _extra_allowed.add(d) + + +def _compile_patterns(): + patterns = conf.get("core", "allowed_deserialization_classes").split() + + _patterns.clear() # ensure to reinit + for p in patterns: + _patterns.append(re.compile(p)) + + +_register() +_compile_patterns() diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 4563e570dddd6..d6340c4b512a9 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -55,7 +55,7 @@ from airflow.timetables.base import Timetable from airflow.utils.code_utils import get_python_source from airflow.utils.docs import get_docs_url -from airflow.utils.module_loading import as_importable_string, import_string +from airflow.utils.module_loading import import_string, qualname from airflow.utils.operator_resources import Resources from airflow.utils.task_group import MappedTaskGroup, TaskGroup @@ -182,7 +182,7 @@ def _encode_timetable(var: Timetable) -> dict[str, Any]: can be completely controlled by a custom subclass. """ timetable_class = type(var) - importable_string = as_importable_string(timetable_class) + importable_string = qualname(timetable_class) if _get_registered_timetable(importable_string) is None: raise _TimetableNotRegistered(importable_string) return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()} @@ -1005,19 +1005,19 @@ def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]: raise AirflowException("Can not load plugins") instances = set() - for qualname in set(deps): + for qn in set(deps): if ( - not qualname.startswith("airflow.ti_deps.deps.") - and qualname not in plugins_manager.registered_ti_dep_classes + not qn.startswith("airflow.ti_deps.deps.") + and qn not in plugins_manager.registered_ti_dep_classes ): raise SerializationError( - f"Custom dep class {qualname} not deserialized, please register it through plugins." + f"Custom dep class {qn} not deserialized, please register it through plugins." ) try: - instances.add(import_string(qualname)()) + instances.add(import_string(qn)()) except ImportError: - log.warning("Error importing dep %r", qualname, exc_info=True) + log.warning("Error importing dep %r", qn, exc_info=True) return instances @classmethod @@ -1108,6 +1108,15 @@ def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOper return serialize_operator_extra_links + @classmethod + def serialize(cls, var: Any, *, strict: bool = False) -> Any: + # the wonders of multiple inheritance BaseOperator defines an instance method + return BaseSerialization.serialize(var=var, strict=strict) + + @classmethod + def deserialize(cls, encoded_var: Any) -> Any: + return BaseSerialization.deserialize(encoded_var=encoded_var) + class SerializedDAG(DAG, BaseSerialization): """ @@ -1159,6 +1168,7 @@ def serialize_dag(cls, dag: DAG) -> dict: del serialized_dag["timetable"] serialized_dag["tasks"] = [cls.serialize(task) for _, task in dag.task_dict.items()] + dag_deps = { dep for task in dag.task_dict.values() diff --git a/airflow/serialization/serializers/__init__.py b/airflow/serialization/serializers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/serialization/serializers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/serialization/serializers/bignum.py b/airflow/serialization/serializers/bignum.py new file mode 100644 index 0000000000000..972963467d8cb --- /dev/null +++ b/airflow/serialization/serializers/bignum.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from decimal import Decimal +from typing import TYPE_CHECKING + +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +serializers = [Decimal] +deserializers = serializers + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, Decimal): + name = qualname(o) + _, _, exponent = o.as_tuple() + if exponent >= 0: # No digits after the decimal point. + return int(o), name, __version__, True + # Technically lossy due to floating point errors, but the best we + # can do without implementing a custom encode function. + return float(o), name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: object) -> Decimal: + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if classname != qualname(Decimal): + raise TypeError(f"{classname} != {qualname(Decimal)}") + + return Decimal(str(data)) diff --git a/airflow/serialization/serializers/datetime.py b/airflow/serialization/serializers/datetime.py new file mode 100644 index 0000000000000..b400258838a0d --- /dev/null +++ b/airflow/serialization/serializers/datetime.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +from pendulum import DateTime +from pendulum.tz import timezone + +from airflow.utils.module_loading import qualname +from airflow.utils.timezone import convert_to_utc, is_naive + +if TYPE_CHECKING: + from airflow.serialization.serde import U + +__version__ = 1 + +serializers = [date, datetime, timedelta, DateTime] +deserializers = serializers + +TIMESTAMP = "timestamp" +TIMEZONE = "tz" + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, DateTime) or isinstance(o, datetime): + qn = qualname(o) + if is_naive(o): + o = convert_to_utc(o) + + tz = o.tzname() + + return {TIMESTAMP: o.timestamp(), TIMEZONE: tz}, qn, __version__, True + + if isinstance(o, date): + return o.isoformat(), qualname(o), __version__, True + + if isinstance(o, timedelta): + return o.total_seconds(), qualname(o), __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: dict | str) -> datetime | timedelta | date: + if classname == qualname(datetime) and isinstance(data, dict): + return datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) + + if classname == qualname(DateTime) and isinstance(data, dict): + return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) + + if classname == qualname(timedelta) and isinstance(data, (str, float)): + return timedelta(seconds=float(data)) + + if classname == qualname(date) and isinstance(data, str): + return date.fromisoformat(data) + + raise TypeError(f"unknown date/time format {classname}") diff --git a/airflow/serialization/serializers/kubernetes.py b/airflow/serialization/serializers/kubernetes.py new file mode 100644 index 0000000000000..a8f0c0f333de5 --- /dev/null +++ b/airflow/serialization/serializers/kubernetes.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from airflow.utils.module_loading import qualname + +serializers = [] + +try: + from kubernetes.client import models as k8s + + serializers = [k8s.v1_pod.V1Pod, k8s.V1ResourceRequirements] +except ImportError: + k8s = None + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +__version__ = 1 + +deserializers: list[type[object]] = [] +log = logging.getLogger(__name__) + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if not k8s: + return "", "", 0, False + + if isinstance(o, (k8s.V1Pod, k8s.V1ResourceRequirements)): + from airflow.kubernetes.pod_generator import PodGenerator + + def safe_get_name(pod): + """ + We're running this in an except block, so we don't want it to + fail under any circumstances, e.g. by accessing an attribute that isn't there + """ + try: + return pod.metadata.name + except Exception: + return None + + try: + return PodGenerator.serialize_pod(o), qualname(o), __version__, True + except Exception: + log.warning("Serialization failed for pod %s", safe_get_name(o)) + log.debug("traceback for serialization error", exc_info=True) + return "", "", 0, False + + return "", "", 0, False diff --git a/airflow/serialization/serializers/numpy.py b/airflow/serialization/serializers/numpy.py new file mode 100644 index 0000000000000..4dea70aa13742 --- /dev/null +++ b/airflow/serialization/serializers/numpy.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.utils.module_loading import qualname + +serializers = [] + +try: + import numpy as np + + serializers = [ + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.bool_, + np.float_, + np.float16, + np.float64, + np.complex_, + np.complex64, + np.complex128, + ] +except ImportError: + np = None # type: ignore + + +if TYPE_CHECKING: + from airflow.serialization.serde import U + +deserializers: list = serializers +_deserializers: dict[str, type[object]] = {qualname(x): x for x in deserializers} + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + if np is None: + return "", "", 0, False + + name = qualname(o) + if isinstance( + o, + ( + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ): + return int(o), name, __version__, True + + if isinstance(o, np.bool_): + return bool(np), name, __version__, True + + if isinstance( + o, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128) + ): + return float(o), name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: str) -> Any: + if version > __version__: + raise TypeError("serialized version is newer than class version") + + f = _deserializers.get(classname, None) + if callable(f): + return f(data) # type: ignore [call-arg] + + raise TypeError(f"unsupported {classname} found for numpy deserialization") diff --git a/airflow/serialization/serializers/timezone.py b/airflow/serialization/serializers/timezone.py new file mode 100644 index 0000000000000..0a0bcc222c611 --- /dev/null +++ b/airflow/serialization/serializers/timezone.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pendulum +from pendulum.tz.timezone import FixedTimezone, Timezone + +from airflow.utils.module_loading import qualname + +if TYPE_CHECKING: + from airflow.serialization.serde import U + + +serializers = [FixedTimezone, Timezone] +deserializers = serializers + +__version__ = 1 + + +def serialize(o: object) -> tuple[U, str, int, bool]: + """Encode a Pendulum Timezone for serialization. + + Airflow only supports timezone objects that implements Pendulum's Timezone + interface. We try to keep as much information as possible to make conversion + round-tripping possible (see ``decode_timezone``). We need to special-case + UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as + 0 without the special case), but passing 0 into ``pendulum.timezone`` does + not give us UTC (but ``+00:00``). + """ + name = qualname(o) + if isinstance(o, FixedTimezone): + if o.offset == 0: + return "UTC", name, __version__, True + return o.offset, name, __version__, True + + if isinstance(o, Timezone): + return o.name, name, __version__, True + + return "", "", 0, False + + +def deserialize(classname: str, version: int, data: object) -> Timezone: + if not isinstance(data, (str, int)): + raise TypeError(f"{data} is not of type int or str but of {type(data)}") + + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if isinstance(data, int): + return pendulum.tz.fixed_timezone(data) + + return pendulum.tz.timezone(data) diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 16f36cb475f58..10806c12bffe4 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -17,118 +17,19 @@ # under the License. from __future__ import annotations -import dataclasses import json -import logging -import re from datetime import date, datetime from decimal import Decimal from typing import Any -import attr from flask.json.provider import JSONProvider -from airflow.configuration import conf -from airflow.serialization.enums import Encoding -from airflow.utils.module_loading import import_string +from airflow.serialization.serde import CLASSNAME, DATA, SCHEMA_ID, deserialize, serialize from airflow.utils.timezone import convert_to_utc, is_naive -try: - import numpy as np -except ImportError: - np = None # type: ignore - -try: - from kubernetes.client import models as k8s -except ImportError: - k8s = None - -# Dates and JSON encoding/decoding - -log = logging.getLogger(__name__) - -CLASSNAME = "__classname__" -VERSION = "__version__" -DATA = "__data__" - -OLD_TYPE = "__type" -OLD_SOURCE = "__source" -OLD_DATA = "__var" - -DEFAULT_VERSION = 0 - - -class AirflowJsonEncoder(json.JSONEncoder): - """Custom Airflow json encoder implementation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.default = self._default - - @staticmethod - def _default(obj): - """Convert dates and numpy objects in a json serializable format.""" - if isinstance(obj, datetime): - if is_naive(obj): - obj = convert_to_utc(obj) - return obj.isoformat() - elif isinstance(obj, date): - return obj.strftime("%Y-%m-%d") - elif isinstance(obj, Decimal): - _, _, exponent = obj.as_tuple() - if exponent >= 0: # No digits after the decimal point. - return int(obj) - # Technically lossy due to floating point errors, but the best we - # can do without implementing a custom encode function. - return float(obj) - elif np is not None and isinstance( - obj, - ( - np.int_, - np.intc, - np.intp, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ), - ): - return int(obj) - elif np is not None and isinstance(obj, np.bool_): - return bool(obj) - elif np is not None and isinstance( - obj, (np.float_, np.float16, np.float32, np.float64, np.complex_, np.complex64, np.complex128) - ): - return float(obj) - elif k8s is not None and isinstance(obj, (k8s.V1Pod, k8s.V1ResourceRequirements)): - from airflow.kubernetes.pod_generator import PodGenerator - - def safe_get_name(pod): - """ - We're running this in an except block, so we don't want it to - fail under any circumstances, e.g. by accessing an attribute that isn't there - """ - try: - return pod.metadata.name - except Exception: - return None - - try: - return PodGenerator.serialize_pod(obj) - except Exception: - log.warning("JSON encoding failed for pod %s", safe_get_name(obj)) - log.debug("traceback for pod JSON encode error", exc_info=True) - return {} - - raise TypeError(f"Object of type '{obj.__class__.__qualname__}' is not JSON serializable") - class AirflowJsonProvider(JSONProvider): - """JSON Provider for Flask app to use AirflowJsonEncoder.""" + """JSON Provider for Flask app to use WebEncoder.""" ensure_ascii: bool = True sort_keys: bool = True @@ -136,41 +37,55 @@ class AirflowJsonProvider(JSONProvider): def dumps(self, obj, **kwargs): kwargs.setdefault("ensure_ascii", self.ensure_ascii) kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder) + return json.dumps(obj, **kwargs, cls=WebEncoder) def loads(self, s: str | bytes, **kwargs): return json.loads(s, **kwargs) -# for now separate as AirflowJsonEncoder is non-standard -class XComEncoder(json.JSONEncoder): - """This encoder serializes any object that has attr, dataclass or a custom serializer.""" +class WebEncoder(json.JSONEncoder): + """This encodes values into a web understandable format. There is no deserializer""" + + def default(self, o: Any) -> Any: + if isinstance(o, datetime): + if is_naive(o): + o = convert_to_utc(o) + return o.isoformat() + + if isinstance(o, date): + return o.strftime("%Y-%m-%d") + + if isinstance(o, Decimal): + data = serialize(o) + if isinstance(data, dict) and DATA in data: + return data[DATA] - def default(self, o: object) -> dict: - from airflow.serialization.serialized_objects import BaseSerialization + try: + data = serialize(o) + if isinstance(data, dict) and CLASSNAME in data: + # this is here for backwards compatibility + if ( + data[CLASSNAME].startswith("numpy") + or data[CLASSNAME] == "kubernetes.client.models.v1_pod.V1Pod" + ): + return data[DATA] + return data + except TypeError: + raise - dct = { - CLASSNAME: o.__module__ + "." + o.__class__.__qualname__, - VERSION: getattr(o.__class__, "version", DEFAULT_VERSION), - } - if hasattr(o, "serialize"): - dct[DATA] = getattr(o, "serialize")() - return dct - elif dataclasses.is_dataclass(o.__class__): - data = dataclasses.asdict(o) - dct[DATA] = BaseSerialization.serialize(data) - return dct - elif attr.has(o.__class__): - # Only include attributes which we can pass back to the classes constructor - data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init) # type: ignore[arg-type] - dct[DATA] = BaseSerialization.serialize(data) - return dct - else: +class XComEncoder(json.JSONEncoder): + """This encoder serializes any object that has attr, dataclass or a custom serializer.""" + + def default(self, o: object) -> Any: + try: + return serialize(o) + except TypeError: return super().default(o) def encode(self, o: Any) -> str: - if isinstance(o, dict) and CLASSNAME in o: + # checked here and in serialize + if isinstance(o, dict) and (CLASSNAME in o or SCHEMA_ID in o): raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize") return super().encode(o) @@ -183,86 +98,16 @@ class XComDecoder(json.JSONDecoder): as is. """ - _pattern: list[re.Pattern] = [] - def __init__(self, *args, **kwargs) -> None: if not kwargs.get("object_hook"): kwargs["object_hook"] = self.object_hook - patterns = conf.get("core", "allowed_deserialization_classes").split() - - self._pattern.clear() # ensure to reinit - for p in patterns: - self._pattern.append(re.compile(p)) - super().__init__(*args, **kwargs) def object_hook(self, dct: dict) -> object: - dct = XComDecoder._convert(dct) - - if CLASSNAME in dct and VERSION in dct: - from airflow.serialization.serialized_objects import BaseSerialization - - classname = dct[CLASSNAME] - cls = None - - for p in self._pattern: - if p.match(classname): - cls = import_string(classname) - break - - if not cls: - raise ImportError(f"{classname} was not found in allow list for import") - - if hasattr(cls, "deserialize"): - return getattr(cls, "deserialize")(dct[DATA], dct[VERSION]) - - version = getattr(cls, "version", 0) - if int(dct[VERSION]) > version: - raise TypeError( - "serialized version of %s is newer than module version (%s > %s)", - dct[CLASSNAME], - dct[VERSION], - version, - ) - - if not attr.has(cls) and not dataclasses.is_dataclass(cls): - raise TypeError( - f"cannot deserialize: no deserialization method " - f"for {dct[CLASSNAME]} and not attr/dataclass decorated" - ) - - return cls(**BaseSerialization.deserialize(dct[DATA])) - - return dct + return deserialize(dct) @staticmethod def orm_object_hook(dct: dict) -> object: """Creates a readable representation of a serialized object""" - dct = XComDecoder._convert(dct) - if CLASSNAME in dct and VERSION in dct: - from airflow.serialization.serialized_objects import BaseSerialization - - if Encoding.VAR in dct[DATA] and Encoding.TYPE in dct[DATA]: - data = BaseSerialization.deserialize(dct[DATA]) - if not isinstance(data, dict): - raise TypeError(f"deserialized value should be a dict, but is {type(data)}") - else: - # custom serializer - data = dct[DATA] - - s = f"{dct[CLASSNAME]}@version={dct[VERSION]}(" - for k, v in data.items(): - s += f"{k}={v}," - s = s[:-1] + ")" - return s - - return dct - - @staticmethod - def _convert(old: dict) -> dict: - """Converts an old style serialization to new style""" - if OLD_TYPE in old and OLD_SOURCE in old: - return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]} - - return old + return deserialize(dct, False) diff --git a/airflow/utils/module_loading.py b/airflow/utils/module_loading.py index dc361f9f95f89..3053a2ed0445f 100644 --- a/airflow/utils/module_loading.py +++ b/airflow/utils/module_loading.py @@ -17,10 +17,13 @@ # under the License. from __future__ import annotations +import pkgutil from importlib import import_module +from types import ModuleType +from typing import Callable -def import_string(dotted_path): +def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. @@ -38,6 +41,24 @@ def import_string(dotted_path): raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') -def as_importable_string(thing) -> str: - """Convert an attribute/class to a string importable by ``import_string``.""" - return f"{thing.__module__}.{thing.__name__}" +def qualname(o: object | Callable) -> str: + """Convert an attribute/class/function to a string importable by ``import_string``.""" + if callable(o): + return f"{o.__module__}.{o.__name__}" + + cls = o + + if not isinstance(cls, type): # instance or class + cls = type(cls) + + name = cls.__qualname__ + module = cls.__module__ + + if module and module != "__builtin__": + return f"{module}.{name}" + + return name + + +def iter_namespace(ns: ModuleType): + return pkgutil.iter_modules(ns.__path__, ns.__name__ + ".") diff --git a/airflow/www/views.py b/airflow/www/views.py index 54800266bebd5..16b4818d7dbdb 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3739,7 +3739,7 @@ def datasets_summary(self): data = {"datasets": datasets, "total_entries": count_query.scalar()} return ( - htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.AirflowJsonEncoder), + htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.WebEncoder), {"Content-Type": "application/json; charset=utf-8"}, ) diff --git a/docs/apache-airflow/concepts/taskflow.rst b/docs/apache-airflow/concepts/taskflow.rst index 7efa38b8430c7..8cc87ee09bb55 100644 --- a/docs/apache-airflow/concepts/taskflow.rst +++ b/docs/apache-airflow/concepts/taskflow.rst @@ -156,7 +156,7 @@ yourself. To do so add the ``serialize()`` method to your class and the staticme class MyCustom: - version: ClassVar[int] = 1 + __version__: ClassVar[int] = 1 def __init__(self, x): self.x = x @@ -174,13 +174,13 @@ Object Versioning ^^^^^^^^^^^^^^^^^ It is good practice to version the objects that will be used in serialization. To do this add -``version: ClassVar[int] = `` to your class. Airflow assumes that your classes are backwards compatible, +``__version__: ClassVar[int] = `` to your class. Airflow assumes that your classes are backwards compatible, so that a version 2 is able to deserialize a version 1. In case you need custom logic for deserialization ensure that ``deserialize(data: dict, version: int)`` is specified. :: - Note: Typing of ``version`` is required and needs to be ``ClassVar[int]`` + Note: Typing of ``__version__`` is required and needs to be ``ClassVar[int]`` History ------- diff --git a/docs/apache-airflow/developer/serializers.rst b/docs/apache-airflow/developer/serializers.rst new file mode 100644 index 0000000000000..69fcd8a155e18 --- /dev/null +++ b/docs/apache-airflow/developer/serializers.rst @@ -0,0 +1,128 @@ + .. Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Serialization +============= + +To support data exchange, like arguments, between tasks, Airflow needs to serialize the data to be exchanged and +deserialize it again when required in a downstream task. Serialization also happens so that the webserver and +the scheduler (as opposed to the dag processor) do no need to read the DAG file. This is done for security purposes +and efficiency. + +Serialization is a surprisingly hard job. Python out of the box only has support for serialization of primitives, +like ``str`` and ``int`` and it loops over iterables. When things become more complex, custom serialization is required. + +Airflow out of the box supports three ways of custom serialization. Primitives are are returned as is, without +additional encoding, e.g. a ``str`` remains a ``str``. When it is not a primitive (or iterable thereof) Airflow +looks for a registered serializer and deserializer in the namespace of ``airflow.serialization.serializers``. +If not found it will look in the class for a ``serialize()`` method or in case of deserialization a +``deserialize(data, version: int)`` method. Finally, if the class is either decorated with ``@dataclass`` +or ``@attr.define`` it will use the public methods for those decorators. + +If you are looking to extend Airflow with a new serializer, it is good to know when to choose what way of serialization. +Objects that are under the control of Airflow, i.e. residing under the namespace of ``airflow.*`` like +``airflow.model.dag.DAG`` or under control of the developer e.g. ``my.company.Foo`` should first be examined to see +whether they can be decorated with ``@attr.define`` or ``@dataclass``. If that is not possible then the ``serialize`` +and ``deserialize`` methods should be implemented. The ``serialize`` method should return a primitive or a dict. +It does not need to serialize the values in the dict, that will be taken care of, but the keys should be of a primitive +form. + +Objects that are not under control of Airflow, e.g. ``numpy.int16`` will need a registered serializer and deserializer. +Versioning is required. Primitives can be returned as can dicts. Again ``dict`` values do not need to be serialized, +but its keys need to be of primitive form. In case you are implementing a registered serializer, take special care +not to have circular imports. Typically, this can be avoided by using ``str`` for populating the list of serializers. +Like so: ``serializers = ["my.company.Foo"]`` instead of ``serializers = [Foo]``. + +:: + + Note: Serialization and deserialization is dependent on speed. Use built-in functions like ``dict`` as much as you can and stay away from using classes and other complex structures. + +Airflow Object +-------------- + +.. code-block:: python + + from typing import Any, ClassVar + + + class Foo: + __version__: ClassVar[int] = 1 + + def __init__(self, a, v) -> None: + self.a = a + self.b = {"x": v} + + def serialize(self) -> dict[str, Any]: + return { + "a": self.a, + "b": self.b, + } + + @staticmethod + def deserialize(data: dict[str, Any], version: int): + f = Foo(a=data["a"]) + f.b = data["b"] + return f + + +Registered +^^^^^^^^^^ + +.. code-block:: python + + from __future__ import annotations + + from decimal import Decimal + from typing import TYPE_CHECKING + + from airflow.utils.module_loading import qualname + + if TYPE_CHECKING: + from airflow.serialization.serde import U + + + serializers = [ + Decimal + ] # this can be a type or a fully qualified str. Str can be used to prevent circular imports + deserializers = serializers # in some cases you might not have a deserializer (e.g. k8s pod) + + __version__ = 1 # required + + # the serializer expects output, classname, version, is_serialized? + def serialize(o: object) -> tuple[U, str, int, bool]: + if isinstance(o, Decimal): + name = qualname(o) + _, _, exponent = o.as_tuple() + if exponent >= 0: # No digits after the decimal point. + return int(o), name, __version__, True + # Technically lossy due to floating point errors, but the best we + # can do without implementing a custom encode function. + return float(o), name, __version__, True + + return "", "", 0, False + + + # the deserializer sanitizes the data for you, so you do not need to deserialize values yourself + def deserialize(classname: str, version: int, data: object) -> Decimal: + # always check version compatibility + if version > __version__: + raise TypeError(f"serialized {version} of {classname} > {__version__}") + + if classname != qualname(Decimal): + raise TypeError(f"{classname} != {qualname(Decimal)}") + + return Decimal(str(data)) diff --git a/docs/apache-airflow/integration.rst b/docs/apache-airflow/integration.rst index a628f95dec6a6..3546e5f98a4a5 100644 --- a/docs/apache-airflow/integration.rst +++ b/docs/apache-airflow/integration.rst @@ -32,6 +32,7 @@ Airflow has a mechanism that allows you to expand its functionality and integrat * :doc:`Secrets backends ` * :doc:`Tracking systems ` * :doc:`Web UI Authentication backends ` +* :doc:`Serialization ` It also has integration with :doc:`Sentry ` service for error tracking. Other applications can also integrate using the :doc:`REST API `. diff --git a/tests/serialization/serializers/__init__.py b/tests/serialization/serializers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/serialization/serializers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py new file mode 100644 index 0000000000000..e4f146b9d9502 --- /dev/null +++ b/tests/serialization/serializers/test_serializers.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import decimal + +import numpy +import pendulum.tz +import pytest +from pendulum import DateTime + +from airflow.models.param import Param, ParamsDict +from airflow.serialization.serde import DATA, deserialize, serialize + + +class TestSerializers: + def test_datetime(self): + i = datetime.datetime(2022, 7, 10, 22, 10, 43, microsecond=0, tzinfo=pendulum.tz.UTC) + + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + + i = DateTime(2022, 7, 10, tzinfo=pendulum.tz.UTC) + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + + i = datetime.date(2022, 7, 10) + s = serialize(i) + d = deserialize(s) + assert i == d + + i = datetime.timedelta(days=320) + s = serialize(i) + d = deserialize(s) + assert i == d + + @pytest.mark.parametrize( + "expr, expected", + [("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")], + ) + def test_encode_decimal(self, expr, expected): + assert deserialize(serialize(decimal.Decimal(expr))) == decimal.Decimal(expected) + + def test_encode_k8s_v1pod(self): + from kubernetes.client import models as k8s + + pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo", + namespace="bar", + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="foo", + image="bar", + ) + ] + ), + ) + assert serialize(pod)[DATA] == { + "metadata": {"name": "foo", "namespace": "bar"}, + "spec": {"containers": [{"image": "bar", "name": "foo"}]}, + } + + def test_numpy(self): + i = numpy.int16(10) + e = serialize(i) + d = deserialize(e) + assert i == d + + def test_params(self): + i = ParamsDict({"x": Param(default="value", description="there is a value", key="test")}) + e = serialize(i) + d = deserialize(e) + assert i["x"] == d["x"] diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py new file mode 100644 index 0000000000000..8a00ed5cf0411 --- /dev/null +++ b/tests/serialization/test_serde.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import enum +from dataclasses import dataclass +from typing import ClassVar + +import attr +import pytest + +from airflow.datasets import Dataset +from airflow.serialization.serde import ( + CLASSNAME, + DATA, + SCHEMA_ID, + VERSION, + _compile_patterns, + deserialize, + serialize, +) +from airflow.utils.module_loading import qualname +from tests.test_utils.config import conf_vars + + +class Z: + __version__: ClassVar[int] = 1 + + def __init__(self, x): + self.x = x + + def serialize(self) -> dict: + return dict({"x": self.x}) + + @staticmethod + def deserialize(data: dict, version: int): + if version != 1: + raise TypeError("version != 1") + return Z(data["x"]) + + +@attr.define +class Y: + x: int + __version__: ClassVar[int] = 1 + + def __init__(self, x): + self.x = x + + +class X: + pass + + +@dataclass +class W: + __version__: ClassVar[int] = 2 + x: int + + +class TestSerDe: + @pytest.fixture(autouse=True) + def ensure_clean_allow_list(self): + _compile_patterns() + yield + + def test_ser_primitives(self): + i = 10 + e = serialize(i) + assert i == e + + i = 10.1 + e = serialize(i) + assert i == e + + i = "test" + e = serialize(i) + assert i == e + + i = True + e = serialize(i) + assert i == e + + Color = enum.IntEnum("Color", ["RED", "GREEN"]) + i = Color.RED + e = serialize(i) + assert i == e + + def test_ser_iterables(self): + i = [1, 2] + e = serialize(i) + assert i == e + + i = ("a", "b", "a", "c") + e = serialize(i) + assert i == e + + i = {2, 3} + e = serialize(i) + assert i == e + + def test_ser_plain_dict(self): + i = {"a": 1, "b": 2} + e = serialize(i) + assert i == e + + with pytest.raises(AttributeError, match="^reserved"): + i = {CLASSNAME: "cannot"} + serialize(i) + + with pytest.raises(AttributeError, match="^reserved"): + i = {SCHEMA_ID: "cannot"} + serialize(i) + + def test_no_serializer(self): + with pytest.raises(TypeError, match="^cannot serialize"): + i = Exception + serialize(i) + + def test_ser_registered(self): + i = datetime.datetime(2000, 10, 1) + e = serialize(i) + assert e[DATA] + + def test_serder_custom(self): + i = Z(1) + e = serialize(i) + assert Z.__version__ == e[VERSION] + assert qualname(Z) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + def test_serder_attr(self): + i = Y(10) + e = serialize(i) + assert Y.__version__ == e[VERSION] + assert qualname(Y) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + def test_serder_dataclass(self): + i = W(12) + e = serialize(i) + assert W.__version__ == e[VERSION] + assert qualname(W) == e[CLASSNAME] + assert e[DATA] + + d = deserialize(e) + assert i.x == getattr(d, "x", None) + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "airflow[.].*", + } + ) + def test_allow_list_for_imports(self): + _compile_patterns() + i = Z(10) + e = serialize(i) + with pytest.raises(ImportError) as ex: + deserialize(e) + + assert f"{qualname(Z)} was not found in allow list" in str(ex.value) + + def test_incompatible_version(self): + data = dict( + { + "__classname__": Y.__module__ + "." + Y.__qualname__, + "__version__": 2, + } + ) + with pytest.raises(TypeError, match="newer than"): + deserialize(data) + + def test_raise_undeserializable(self): + data = dict( + { + "__classname__": X.__module__ + "." + X.__qualname__, + "__version__": 0, + } + ) + with pytest.raises(TypeError, match="No deserializer"): + deserialize(data) + + def test_backwards_compat(self): + uri = "s3://does_not_exist" + data = { + "__type": "airflow.datasets.Dataset", + "__source": None, + "__var": { + "__var": { + "uri": uri, + "extra": None, + }, + "__type": "dict", + }, + } + dataset = deserialize(data) + assert dataset.uri == uri + + def test_encode_dataset(self): + dataset = Dataset("mytest://dataset") + obj = deserialize(serialize(dataset)) + assert dataset.uri == obj.uri diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index 2376b5ce2fd36..bcf11dff10990 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -17,112 +17,53 @@ # under the License. from __future__ import annotations -import decimal import json from dataclasses import dataclass from datetime import date, datetime from typing import ClassVar -import attr import numpy as np import pendulum import pytest from airflow.datasets import Dataset from airflow.utils import json as utils_json -from tests.test_utils.config import conf_vars - - -class Z: - version = 1 - - def __init__(self, x): - self.x = x - - def serialize(self) -> dict: - return dict({"x": self.x}) - - @staticmethod - def deserialize(data: dict, version: int): - if version != 1: - raise TypeError("version != 1") - return Z(data["x"]) - - -@attr.define -class Y: - x: int - version: ClassVar[int] = 1 - - def __init__(self, x): - self.x = x - - -class X: - pass @dataclass class U: - version: ClassVar[int] = 2 + __version__: ClassVar[int] = 2 x: int -class TestAirflowJsonEncoder: +class TestWebEncoder: def test_encode_datetime(self): obj = datetime.strptime("2017-05-21 00:00:00", "%Y-%m-%d %H:%M:%S") - assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+00:00"' + assert json.dumps(obj, cls=utils_json.WebEncoder) == '"2017-05-21T00:00:00+00:00"' def test_encode_pendulum(self): obj = pendulum.datetime(2017, 5, 21, tz="Asia/Kolkata") - assert json.dumps(obj, cls=utils_json.AirflowJsonEncoder) == '"2017-05-21T00:00:00+05:30"' + assert json.dumps(obj, cls=utils_json.WebEncoder) == '"2017-05-21T00:00:00+05:30"' def test_encode_date(self): - assert json.dumps(date(2017, 5, 21), cls=utils_json.AirflowJsonEncoder) == '"2017-05-21"' - - @pytest.mark.parametrize( - "expr, expected", - [("1", "1"), ("52e4", "520000"), ("2e0", "2"), ("12e-2", "0.12"), ("12.34", "12.34")], - ) - def test_encode_decimal(self, expr, expected): - assert json.dumps(decimal.Decimal(expr), cls=utils_json.AirflowJsonEncoder) == expected + assert json.dumps(date(2017, 5, 21), cls=utils_json.WebEncoder) == '"2017-05-21"' def test_encode_numpy_int(self): - assert json.dumps(np.int32(5), cls=utils_json.AirflowJsonEncoder) == "5" + assert json.dumps(np.int32(5), cls=utils_json.WebEncoder) == "5" def test_encode_numpy_bool(self): - assert json.dumps(np.bool_(True), cls=utils_json.AirflowJsonEncoder) == "true" + assert json.dumps(np.bool_(True), cls=utils_json.WebEncoder) == "true" def test_encode_numpy_float(self): - assert json.dumps(np.float16(3.76953125), cls=utils_json.AirflowJsonEncoder) == "3.76953125" - - def test_encode_k8s_v1pod(self): - from kubernetes.client import models as k8s - - pod = k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - name="foo", - namespace="bar", - ), - spec=k8s.V1PodSpec( - containers=[ - k8s.V1Container( - name="foo", - image="bar", - ) - ] - ), - ) - assert json.loads(json.dumps(pod, cls=utils_json.AirflowJsonEncoder)) == { - "metadata": {"name": "foo", "namespace": "bar"}, - "spec": {"containers": [{"image": "bar", "name": "foo"}]}, - } + assert json.dumps(np.float16(3.76953125), cls=utils_json.WebEncoder) == "3.76953125" + +class TestXComEncoder: def test_encode_raises(self): with pytest.raises(TypeError, match="^.*is not JSON serializable$"): json.dumps( Exception, - cls=utils_json.AirflowJsonEncoder, + cls=utils_json.XComEncoder, ) def test_encode_xcom_dataset(self): @@ -131,103 +72,9 @@ def test_encode_xcom_dataset(self): obj = json.loads(s, cls=utils_json.XComDecoder) assert dataset.uri == obj.uri - def test_backwards_compat(self): - uri = "s3://does_not_exist" - data = { - "__type": "airflow.datasets.Dataset", - "__source": None, - "__var": { - "__var": { - "uri": uri, - "extra": None, - }, - "__type": "dict", - }, - } - decoder = utils_json.XComDecoder() - dataset = decoder.object_hook(data) - assert dataset.uri == uri - - def test_raise_on_reserved(self): - data = {"__classname__": "my.class"} - with pytest.raises(AttributeError): - json.dumps(data, cls=utils_json.XComEncoder) - - def test_custom_serialize(self): - x = 11 - z = Z(x) - s = json.dumps(z, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - - def test_raise_undeserializable(self): - data = dict( - { - "__classname__": X.__module__ + "." + X.__qualname__, - "__version__": 0, - } - ) - s = json.dumps(data) - with pytest.raises(TypeError) as info: - json.loads(s, cls=utils_json.XComDecoder) - - assert "cannot deserialize" in str(info.value) - - def test_attr_version(self): - x = 14 - y = Y(x) - s = json.dumps(y, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - - def test_incompatible_version(self): - data = dict( - { - "__classname__": Y.__module__ + "." + Y.__qualname__, - "__version__": 2, - } - ) - s = json.dumps(data) - with pytest.raises(TypeError) as info: - json.loads(s, cls=utils_json.XComDecoder) - - assert "newer than" in str(info.value) - - def test_dataclass(self): - x = 12 - u = U(x=x) - s = json.dumps(u, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder) - - assert o.x == x - def test_orm_deserialize(self): x = 14 u = U(x=x) s = json.dumps(u, cls=utils_json.XComEncoder) o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) - assert o == f"{U.__module__}.{U.__qualname__}@version={U.version}(x={x})" - - def test_orm_custom_deserialize(self): - x = 14 - z = Z(x=x) - s = json.dumps(z, cls=utils_json.XComEncoder) - o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) - assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})" - - @conf_vars( - { - ("core", "allowed_deserialization_classes"): "airflow[.].*", - } - ) - def test_allow_list_for_imports(self): - x = 14 - z = Z(x=x) - s = json.dumps(z, cls=utils_json.XComEncoder) - - with pytest.raises(ImportError) as e: - json.loads(s, cls=utils_json.XComDecoder) - - assert f"{Z.__module__}.{Z.__qualname__} was not found in allow list" in str(e.value) + assert o == f"{U.__module__}.{U.__qualname__}@version={U.__version__}(x={x})" From d99a9480859af705604e9f0a6a59c5939536c223 Mon Sep 17 00:00:00 2001 From: bolkedebruin Date: Wed, 7 Dec 2022 18:21:30 +0100 Subject: [PATCH 2/6] Update docs/apache-airflow/developer/serializers.rst Co-authored-by: Kaxil Naik --- docs/apache-airflow/developer/serializers.rst | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/apache-airflow/developer/serializers.rst b/docs/apache-airflow/developer/serializers.rst index 69fcd8a155e18..80aa8219ee39f 100644 --- a/docs/apache-airflow/developer/serializers.rst +++ b/docs/apache-airflow/developer/serializers.rst @@ -1,19 +1,19 @@ .. Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at .. http://www.apache.org/licenses/LICENSE-2.0 .. Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. Serialization ============= From ae91f354bd04a7258202a7851a925589a0e09640 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Wed, 7 Dec 2022 18:30:46 +0100 Subject: [PATCH 3/6] Add backwards compatibility --- airflow/utils/json.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 10806c12bffe4..0a497aab898e2 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -111,3 +111,7 @@ def object_hook(self, dct: dict) -> object: def orm_object_hook(dct: dict) -> object: """Creates a readable representation of a serialized object""" return deserialize(dct, False) + + +# backwards compatibility +AirflowJsonEncoder = WebEncoder From 71358a4685e392bc99a54c5157348c853346aa47 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Thu, 8 Dec 2022 11:08:01 +0100 Subject: [PATCH 4/6] Move docs --- docs/apache-airflow/concepts/index.rst | 1 + docs/apache-airflow/{developer => concepts}/serializers.rst | 0 docs/apache-airflow/integration.rst | 2 +- 3 files changed, 2 insertions(+), 1 deletion(-) rename docs/apache-airflow/{developer => concepts}/serializers.rst (100%) diff --git a/docs/apache-airflow/concepts/index.rst b/docs/apache-airflow/concepts/index.rst index 6663fc004b333..61704dacc1afa 100644 --- a/docs/apache-airflow/concepts/index.rst +++ b/docs/apache-airflow/concepts/index.rst @@ -48,6 +48,7 @@ Here you can find detailed documentation about each one of Airflow's core concep timetable priority-weight cluster-policies + serializers **Communication** diff --git a/docs/apache-airflow/developer/serializers.rst b/docs/apache-airflow/concepts/serializers.rst similarity index 100% rename from docs/apache-airflow/developer/serializers.rst rename to docs/apache-airflow/concepts/serializers.rst diff --git a/docs/apache-airflow/integration.rst b/docs/apache-airflow/integration.rst index 3546e5f98a4a5..e77f72ea5af3c 100644 --- a/docs/apache-airflow/integration.rst +++ b/docs/apache-airflow/integration.rst @@ -32,7 +32,7 @@ Airflow has a mechanism that allows you to expand its functionality and integrat * :doc:`Secrets backends ` * :doc:`Tracking systems ` * :doc:`Web UI Authentication backends ` -* :doc:`Serialization ` +* :doc:`Serialization ` It also has integration with :doc:`Sentry ` service for error tracking. Other applications can also integrate using the :doc:`REST API `. From 1b95e6f61470631bc504f0d29e17918adbf4afa5 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Thu, 8 Dec 2022 16:01:03 +0100 Subject: [PATCH 5/6] Add words --- docs/spelling_wordlist.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 5b186ae821fb5..a702b903eb56d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -409,6 +409,7 @@ desc deserialization Deserialize deserialize +deserializer Deserialized deserialized deserializing @@ -1276,6 +1277,8 @@ Sendgrid sendgrid serde serialise +serializer +serializers serializable SerializedDAG serverless From 194939a6df9d75e53be339bcf85bafdb7873862c Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Thu, 8 Dec 2022 23:00:38 +0100 Subject: [PATCH 6/6] last fixes --- airflow/providers/amazon/aws/hooks/eks.py | 10 ++++++---- airflow/providers/amazon/aws/operators/sagemaker.py | 4 ++-- docs/spelling_wordlist.txt | 6 +++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py index d74b2a50536d9..03488dea96771 100644 --- a/airflow/providers/amazon/aws/hooks/eks.py +++ b/airflow/providers/amazon/aws/hooks/eks.py @@ -31,7 +31,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.utils import yaml -from airflow.utils.json import WebEncoder +from airflow.utils.json import AirflowJsonEncoder DEFAULT_PAGINATION_TOKEN = "" STS_TOKEN_EXPIRES_IN = 60 @@ -276,7 +276,7 @@ def describe_cluster(self, name: str, verbose: bool = False) -> dict: ) if verbose: cluster_data = response.get("cluster") - self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=WebEncoder)) + self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder)) return response def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict: @@ -302,7 +302,7 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool nodegroup_data = response.get("nodegroup") self.log.info( "Amazon EKS managed node group details: %s", - json.dumps(nodegroup_data, cls=WebEncoder), + json.dumps(nodegroup_data, cls=AirflowJsonEncoder), ) return response @@ -331,7 +331,9 @@ def describe_fargate_profile( ) if verbose: fargate_profile_data = response.get("fargateProfile") - self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=WebEncoder)) + self.log.info( + "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder) + ) return response def get_cluster_state(self, clusterName: str) -> ClusterStates: diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 3acba423eb373..4b969002b339e 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -26,7 +26,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook -from airflow.utils.json import WebEncoder +from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: from airflow.utils.context import Context @@ -36,7 +36,7 @@ def serialize(result: dict) -> str: - return json.loads(json.dumps(result, cls=WebEncoder)) + return json.loads(json.dumps(result, cls=AirflowJsonEncoder)) class SageMakerBaseOperator(BaseOperator): diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a702b903eb56d..e05912d2748a1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -409,9 +409,9 @@ desc deserialization Deserialize deserialize -deserializer Deserialized deserialized +deserializer deserializing dest dev @@ -1277,10 +1277,10 @@ Sendgrid sendgrid serde serialise -serializer -serializers serializable SerializedDAG +serializer +serializers serverless ServiceAccount setattr