Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@
example: ~
default: "False"
see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json"
- name: allowed_deserialization_classes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only affects xcom right now. Is the goal to eventually make it affect serialised dags too, or should we change the config name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to make it the generic serializer for Airflow, so eventually also DAGs (pending feasibility)

description: |
What classes can be imported during deserialization. This is a multi line value.
The individual items will be parsed as regexp. Python built-in classes (like dict)
are always allowed.
version_added: 2.5.0
type: string
default: 'airflow\..*'
example: ~
- name: killed_task_cleanup_time
description: |
When a task is killed forcefully, this is the amount of time in seconds that
Expand Down
5 changes: 5 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ unit_test_mode = False
# RCE exploits).
enable_xcom_pickling = False

# What classes can be imported during deserialization. This is a multi line value.
# The individual items will be parsed as regexp. Python built-in classes (like dict)
# are always allowed.
allowed_deserialization_classes = airflow\..*

# When a task is killed forcefully, this is the amount of time in seconds that
# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED
killed_task_cleanup_time = 60
Expand Down
2 changes: 2 additions & 0 deletions airflow/config_templates/default_test.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ plugins_folder = {TEST_PLUGINS_FOLDER}
dags_are_paused_at_creation = False
fernet_key = {FERNET_KEY}
killed_task_cleanup_time = 5
allowed_deserialization_classes = airflow\..*
tests\..*

[database]
sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db
Expand Down
24 changes: 21 additions & 3 deletions airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import dataclasses
import json
import logging
import re
from datetime import date, datetime
from decimal import Decimal
from typing import Any

import attr
from flask.json.provider import JSONProvider

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.module_loading import import_string
from airflow.utils.timezone import convert_to_utc, is_naive
Expand Down Expand Up @@ -181,20 +183,36 @@ class XComDecoder(json.JSONDecoder):
as is.
"""

_pattern: list[re.Pattern] = []

def __init__(self, *args, **kwargs) -> None:
    """Initialize the decoder and compile the deserialization allow list.

    Reads ``[core] allowed_deserialization_classes`` — a whitespace-separated
    list of regular expressions — and compiles each entry. Only classes whose
    fully qualified name matches one of these patterns may be imported during
    deserialization (see ``object_hook``).
    """
    if not kwargs.get("object_hook"):
        kwargs["object_hook"] = self.object_hook

    patterns = conf.get("core", "allowed_deserialization_classes").split()

    # Bind the compiled patterns to THIS instance instead of clearing and
    # re-filling the shared class-level list: mutating the class attribute
    # would silently change the allow list of every live decoder whenever a
    # new one is constructed under a different configuration.
    self._pattern = [re.compile(p) for p in patterns]

    super().__init__(*args, **kwargs)

@staticmethod
def object_hook(dct: dict) -> object:
def object_hook(self, dct: dict) -> object:
dct = XComDecoder._convert(dct)

if CLASSNAME in dct and VERSION in dct:
from airflow.serialization.serialized_objects import BaseSerialization

cls = import_string(dct[CLASSNAME])
classname = dct[CLASSNAME]
cls = None

for p in self._pattern:
if p.match(classname):
cls = import_string(classname)
break

if not cls:
raise ImportError(f"{classname} was not found in allow list for import")

if hasattr(cls, "deserialize"):
return getattr(cls, "deserialize")(dct[DATA], dct[VERSION])
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from airflow.datasets import Dataset
from airflow.utils import json as utils_json
from tests.test_utils.config import conf_vars


class Z:
Expand Down Expand Up @@ -215,3 +216,18 @@ def test_orm_custom_deserialize(self):
s = json.dumps(z, cls=utils_json.XComEncoder)
o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook)
assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})"

@conf_vars(
    {
        ("core", "allowed_deserialization_classes"): "airflow[.].*",
    }
)
def test_allow_list_for_imports(self):
    # A class outside the configured allow list must be rejected on decode.
    value = 14
    obj = Z(x=value)
    serialized = json.dumps(obj, cls=utils_json.XComEncoder)

    with pytest.raises(ImportError) as exc_info:
        json.loads(serialized, cls=utils_json.XComDecoder)

    expected = f"{Z.__module__}.{Z.__qualname__} was not found in allow list"
    assert expected in str(exc_info.value)