diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index eef9fa0381bdd..139c1c35f54d6 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -210,6 +210,15 @@ example: ~ default: "False" see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json" + - name: allowed_deserialization_classes + description: | + What classes can be imported during deserialization. This is a multi line value. + The individual items will be parsed as regexp. Python built-in classes (like dict) + are always allowed + version_added: 2.5.0 + type: string + default: 'airflow\..*' + example: ~ - name: killed_task_cleanup_time description: | When a task is killed forcefully, this is the amount of time in seconds that diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index d1f7069cbbdbc..badf0ef3b2231 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -127,6 +127,11 @@ unit_test_mode = False # RCE exploits). enable_xcom_pickling = False +# What classes can be imported during deserialization. This is a multi line value. +# The individual items will be parsed as regexp. Python built-in classes (like dict) +# are always allowed +allowed_deserialization_classes = airflow\..* + # When a task is killed forcefully, this is the amount of time in seconds that # it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED killed_task_cleanup_time = 60 diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index ed5f0d372342c..523f52cb69a04 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -35,6 +35,8 @@ plugins_folder = {TEST_PLUGINS_FOLDER} dags_are_paused_at_creation = False fernet_key = {FERNET_KEY} killed_task_cleanup_time = 5 +allowed_deserialization_classes = airflow\..* + tests\..* [database] sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 36ed242500b5d..16f36cb475f58 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -20,6 +20,7 @@ import dataclasses import json import logging +import re from datetime import date, datetime from decimal import Decimal from typing import Any @@ -27,6 +28,7 @@ import attr from flask.json.provider import JSONProvider +from airflow.configuration import conf from airflow.serialization.enums import Encoding from airflow.utils.module_loading import import_string from airflow.utils.timezone import convert_to_utc, is_naive @@ -181,20 +183,36 @@ class XComDecoder(json.JSONDecoder): as is. """ + _pattern: list[re.Pattern] = [] + def __init__(self, *args, **kwargs) -> None: if not kwargs.get("object_hook"): kwargs["object_hook"] = self.object_hook + patterns = conf.get("core", "allowed_deserialization_classes").split() + + self._pattern.clear() # ensure to reinit + for p in patterns: + self._pattern.append(re.compile(p)) + super().__init__(*args, **kwargs) - @staticmethod - def object_hook(dct: dict) -> object: + def object_hook(self, dct: dict) -> object: dct = XComDecoder._convert(dct) if CLASSNAME in dct and VERSION in dct: from airflow.serialization.serialized_objects import BaseSerialization - cls = import_string(dct[CLASSNAME]) + classname = dct[CLASSNAME] + cls = None + + for p in self._pattern: + if p.match(classname): + cls = import_string(classname) + break + + if not cls: + raise ImportError(f"{classname} was not found in allow list for import") if hasattr(cls, "deserialize"): return getattr(cls, "deserialize")(dct[DATA], dct[VERSION]) diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index a2f7fc2ee53e8..2376b5ce2fd36 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -30,6 +30,7 @@ from airflow.datasets import Dataset from airflow.utils import json as utils_json +from tests.test_utils.config import conf_vars class Z: @@ -215,3 +216,18 @@ def test_orm_custom_deserialize(self): s = json.dumps(z, cls=utils_json.XComEncoder) o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})" + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "airflow[.].*", + } + ) + def test_allow_list_for_imports(self): + x = 14 + z = Z(x=x) + s = json.dumps(z, cls=utils_json.XComEncoder) + + with pytest.raises(ImportError) as e: + json.loads(s, cls=utils_json.XComDecoder) + + assert f"{Z.__module__}.{Z.__qualname__} was not found in allow list" in str(e.value)