Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import logging
import warnings
from typing import TYPE_CHECKING, Any, ItemsView, Iterable, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView

from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
from airflow.utils.context import Context
Expand All @@ -47,6 +47,8 @@ class Param:
default & description will form the schema
"""

__version__: ClassVar[int] = 1

CLASS_IDENTIFIER = "__class"

def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
Expand Down Expand Up @@ -112,6 +114,16 @@ def dump(self) -> dict:
def has_value(self) -> bool:
return self.value is not NOTSET

def serialize(self) -> dict:
return {"value": self.value, "description": self.description, "schema": self.schema}

@staticmethod
def deserialize(data: dict[str, Any], version: int) -> Param:
if version > Param.__version__:
raise TypeError("serialized version > class version")

return Param(default=data["value"], description=data["description"], schema=data["schema"])


class ParamsDict(MutableMapping[str, Any]):
"""
Expand All @@ -120,6 +132,7 @@ class ParamsDict(MutableMapping[str, Any]):
dictionary implicitly and ideally not needed to be used directly.
"""

__version__: ClassVar[int] = 1
__slots__ = ["__dict", "suppress_exception"]

def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False):
Expand Down Expand Up @@ -231,6 +244,16 @@ def validate(self) -> dict[str, Any]:

return resolved_dict

def serialize(self) -> dict[str, Any]:
return self.dump()

@staticmethod
def deserialize(data: dict, version: int) -> ParamsDict:
if version > ParamsDict.__version__:
raise TypeError("serialized version > class version")

return ParamsDict(data)


class DagParam(ResolveMixin):
"""DAG run parameter reference.
Expand Down
21 changes: 8 additions & 13 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import as_importable_string
from airflow.utils.module_loading import qualname

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
Expand Down Expand Up @@ -373,7 +373,7 @@ def initialize_ti_deps_plugins():

for plugin in plugins:
registered_ti_dep_classes.update(
{as_importable_string(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps}
{qualname(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps}
)


Expand Down Expand Up @@ -406,7 +406,7 @@ def initialize_extra_operators_links_plugins():
operator_extra_links.extend(list(plugin.operator_extra_links))

registered_operator_link_classes.update(
{as_importable_string(link.__class__): link.__class__ for link in plugin.operator_extra_links}
{qualname(link.__class__): link.__class__ for link in plugin.operator_extra_links}
)


Expand All @@ -425,7 +425,7 @@ def initialize_timetables_plugins():
log.debug("Initialize extra timetables plugins")

timetable_classes = {
as_importable_string(timetable_class): timetable_class
qualname(timetable_class): timetable_class
for plugin in plugins
for timetable_class in plugin.timetables
}
Expand Down Expand Up @@ -525,25 +525,20 @@ def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str
info: dict[str, Any] = {"name": plugin.name}
for attr in attrs_to_dump:
if attr in ("global_operator_extra_links", "operator_extra_links"):
info[attr] = [
f"<{as_importable_string(d.__class__)} object>" for d in getattr(plugin, attr)
]
info[attr] = [f"<{qualname(d.__class__)} object>" for d in getattr(plugin, attr)]
elif attr in ("macros", "timetables", "hooks", "executors"):
info[attr] = [as_importable_string(d) for d in getattr(plugin, attr)]
info[attr] = [qualname(d) for d in getattr(plugin, attr)]
elif attr == "listeners":
# listeners are always modules
info[attr] = [d.__name__ for d in getattr(plugin, attr)]
elif attr == "appbuilder_views":
info[attr] = [
{**d, "view": as_importable_string(d["view"].__class__) if "view" in d else None}
{**d, "view": qualname(d["view"].__class__) if "view" in d else None}
for d in getattr(plugin, attr)
]
elif attr == "flask_blueprints":
info[attr] = [
(
f"<{as_importable_string(d.__class__)}: "
f"name={d.name!r} import_name={d.import_name!r}>"
)
f"<{qualname(d.__class__)}: name={d.name!r} import_name={d.import_name!r}>"
for d in getattr(plugin, attr)
]
else:
Expand Down
Loading