From 2271cff5f3f4052726b6c3ba93da7a9864fdf5ae Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Thu, 24 Nov 2022 10:52:20 +0100 Subject: [PATCH 1/2] Add allow list for imports during deserialization During deserialization Airflow can instantiate arbitrary objects for which it imports modules. This can be dangerous as it could lead to unwanted effects. With this change administrators can now limit what objects can be deserialized. It defaults to Airflow's own only. --- airflow/config_templates/config.yml | 9 ++++++++ airflow/config_templates/default_airflow.cfg | 5 ++++ airflow/config_templates/default_test.cfg | 2 ++ airflow/utils/json.py | 24 +++++++++++++++++--- tests/utils/test_json.py | 16 +++++++++++++ 5 files changed, 53 insertions(+), 3 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index eef9fa0381bdd..3e7adaad6ba9f 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -210,6 +210,15 @@ example: ~ default: "False" see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json" + - name: allowed_deserialization_classes + description: | + What classes can be imported during deserialization. This is a multiline value. + The individual items will be parsed as regexp. Python built-in classes (like dict) + are always allowed + version_added: 2.5.0 + type: string + default: 'airflow\..*' + example: ~ - name: killed_task_cleanup_time description: | When a task is killed forcefully, this is the amount of time in seconds that diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index d1f7069cbbdbc..3eeb904b6db27 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -127,6 +127,11 @@ unit_test_mode = False # RCE exploits). enable_xcom_pickling = False +# What classes can be imported during deserialization. This is a multiline value. +# The individual items will be parsed as regexp. Python built-in classes (like dict) +# are always allowed +allowed_deserialization_classes = airflow\..* + # When a task is killed forcefully, this is the amount of time in seconds that # it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED killed_task_cleanup_time = 60 diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index ed5f0d372342c..523f52cb69a04 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -35,6 +35,8 @@ plugins_folder = {TEST_PLUGINS_FOLDER} dags_are_paused_at_creation = False fernet_key = {FERNET_KEY} killed_task_cleanup_time = 5 +allowed_deserialization_classes = airflow\..* + tests\..* [database] sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 36ed242500b5d..16f36cb475f58 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -20,6 +20,7 @@ import dataclasses import json import logging +import re from datetime import date, datetime from decimal import Decimal from typing import Any @@ -27,6 +28,7 @@ import attr from flask.json.provider import JSONProvider +from airflow.configuration import conf from airflow.serialization.enums import Encoding from airflow.utils.module_loading import import_string from airflow.utils.timezone import convert_to_utc, is_naive @@ -181,20 +183,36 @@ class XComDecoder(json.JSONDecoder): as is. """ + _pattern: list[re.Pattern] = [] + def __init__(self, *args, **kwargs) -> None: if not kwargs.get("object_hook"): kwargs["object_hook"] = self.object_hook + patterns = conf.get("core", "allowed_deserialization_classes").split() + + self._pattern.clear() # ensure to reinit + for p in patterns: + self._pattern.append(re.compile(p)) + super().__init__(*args, **kwargs) - @staticmethod - def object_hook(dct: dict) -> object: + def object_hook(self, dct: dict) -> object: dct = XComDecoder._convert(dct) if CLASSNAME in dct and VERSION in dct: from airflow.serialization.serialized_objects import BaseSerialization - cls = import_string(dct[CLASSNAME]) + classname = dct[CLASSNAME] + cls = None + + for p in self._pattern: + if p.match(classname): + cls = import_string(classname) + break + + if not cls: + raise ImportError(f"{classname} was not found in allow list for import") if hasattr(cls, "deserialize"): return getattr(cls, "deserialize")(dct[DATA], dct[VERSION]) diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index a2f7fc2ee53e8..2376b5ce2fd36 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -30,6 +30,7 @@ from airflow.datasets import Dataset from airflow.utils import json as utils_json +from tests.test_utils.config import conf_vars class Z: @@ -215,3 +216,18 @@ def test_orm_custom_deserialize(self): s = json.dumps(z, cls=utils_json.XComEncoder) o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook) assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})" + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "airflow[.].*", + } + ) + def test_allow_list_for_imports(self): + x = 14 + z = Z(x=x) + s = json.dumps(z, cls=utils_json.XComEncoder) + + with pytest.raises(ImportError) as e: + json.loads(s, cls=utils_json.XComDecoder) + + assert f"{Z.__module__}.{Z.__qualname__} was not found in allow list" in str(e.value) From 75a69e67cbf17d567bfa580a40b140318622ed3a Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Fri, 25 Nov 2022 22:26:55 +0100 Subject: [PATCH 2/2] Grrr --- airflow/config_templates/config.yml | 2 +- airflow/config_templates/default_airflow.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 3e7adaad6ba9f..139c1c35f54d6 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -212,7 +212,7 @@ see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json" - name: allowed_deserialization_classes description: | - What classes can be imported during deserialization. This is a multiline value. + What classes can be imported during deserialization. This is a multi line value. The individual items will be parsed as regexp. Python built-in classes (like dict) are always allowed version_added: 2.5.0 diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 3eeb904b6db27..badf0ef3b2231 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -127,7 +127,7 @@ unit_test_mode = False # RCE exploits). enable_xcom_pickling = False -# What classes can be imported during deserialization. This is a multiline value. +# What classes can be imported during deserialization. This is a multi line value. # The individual items will be parsed as regexp. Python built-in classes (like dict) # are always allowed allowed_deserialization_classes = airflow\..*