diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index a94683f6ed537..e38766cd4df21 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -18,6 +18,7 @@ from __future__ import annotations import itertools +import json import os import warnings from collections import defaultdict @@ -96,6 +97,37 @@ class TISchedulingDecision(NamedTuple): finished_tis: list[TI] +class ConfDict(dict): + """Custom dictionary for storing only JSON serializable values.""" + + def __init__(self, val=None): + super().__init__(self.is_jsonable(val)) + + def __setitem__(self, key, value): + self.is_jsonable({key: value}) + super().__setitem__(key, value) + + @staticmethod + def is_jsonable(conf: dict) -> dict | None: + """Prevent setting non-json attributes.""" + try: + json.dumps(conf) + except TypeError: + raise AirflowException("Cannot assign non JSON Serializable value") + if isinstance(conf, dict): + return conf + else: + raise AirflowException(f"Object of type {type(conf)} must be a dict") + + @staticmethod + def dump_check(conf: str) -> str: + val = json.loads(conf) + if isinstance(val, dict): + return conf + else: + raise TypeError(f"Object of type {type(val)} must be a dict") + + def _creator_note(val): """Creator the ``note`` association proxy.""" if isinstance(val, str): @@ -126,7 +158,7 @@ class DagRun(Base, LoggingMixin): creating_job_id = Column(Integer) external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) - conf = Column(PickleType) + _conf = Column("conf", PickleType) # These two must be either both NULL or both datetime. data_interval_start = Column(UtcDateTime) data_interval_end = Column(UtcDateTime) @@ -228,7 +260,12 @@ def __init__( self.execution_date = execution_date self.start_date = start_date self.external_trigger = external_trigger - self.conf = conf or {} + + if isinstance(conf, str): + self._conf = ConfDict.dump_check(conf) + else: + self._conf = ConfDict(conf or {}) + if state is not None: self.state = state if queued_at is NOTSET: @@ -258,6 +295,16 @@ def validate_run_id(self, key: str, run_id: str) -> str | None: ) return run_id + def get_conf(self): + return self._conf + + def set_conf(self, value): + self._conf = ConfDict(value) + + @declared_attr + def conf(self): + return synonym("_conf", descriptor=property(self.get_conf, self.set_conf)) + @property def stats_tags(self) -> dict[str, str]: return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type}) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 9fc2ea86d472b..4a2ce9023b58f 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -2617,3 +2617,17 @@ def test_dag_run_id_config(session, dag_maker, pattern, run_id, result): else: with pytest.raises(AirflowException): dag_maker.create_dagrun(run_id=run_id) + + +def test_dagrun_conf(): + dag_run = DagRun(conf={"test": 1234}) + assert dag_run.conf == {"test": 1234} + + with pytest.raises(AirflowException) as err: + dag_run.conf["non_json"] = timezone.utcnow() + assert str(err.value) == "Cannot assign non JSON Serializable value" + + with pytest.raises(AirflowException) as err: + value = 1 + dag_run.conf = value + assert str(err.value) == f"Object of type {type(value)} must be a dict"