Skip to content
51 changes: 49 additions & 2 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import itertools
import json
import os
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -96,6 +97,37 @@ class TISchedulingDecision(NamedTuple):
finished_tis: list[TI]


class ConfDict(dict):
    """Custom dictionary for storing only JSON serializable values.

    Validation runs on construction and on item assignment
    (``conf[key] = value``). Other inherited ``dict`` mutators such as
    ``update`` are not overridden by this class.
    """

    def __init__(self, val=None):
        """Create the dictionary from ``val`` after validating it.

        :param val: initial content; ``None`` (the default) yields an empty dict.
        :raises AirflowException: if ``val`` is not JSON serializable or is
            not a dict.
        """
        # ``None`` is the declared default, so it must mean "no initial
        # content" rather than failing the dict check in ``is_jsonable``.
        super().__init__(self.is_jsonable({} if val is None else val))

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` once the pair is known to be JSON serializable.

        :raises AirflowException: if the ``{key: value}`` pair cannot be
            serialized to JSON.
        """
        self.is_jsonable({key: value})
        super().__setitem__(key, value)

    @staticmethod
    def is_jsonable(conf: dict) -> dict:
        """Prevent setting non-json attributes.

        :param conf: candidate configuration.
        :return: ``conf`` itself when it is a JSON serializable dict.
        :raises AirflowException: if ``conf`` cannot be serialized to JSON,
            or if it serializes but is not a dict.
        """
        # Serializability is checked first so that a non-dict, non-JSON value
        # reports the serialization error rather than the type error.
        try:
            json.dumps(conf)
        except TypeError as err:
            raise AirflowException("Cannot assign non JSON Serializable value") from err
        if not isinstance(conf, dict):
            raise AirflowException(f"Object of type {type(conf)} must be a dict")
        return conf

    @staticmethod
    def dump_check(conf: str) -> str:
        """Check that the JSON document ``conf`` decodes to a dict.

        :param conf: JSON text.
        :return: the unchanged ``conf`` string when it decodes to a dict.
        :raises TypeError: if the decoded value is not a dict.
        """
        val = json.loads(conf)
        if not isinstance(val, dict):
            raise TypeError(f"Object of type {type(val)} must be a dict")
        return conf


def _creator_note(val):
"""Creator the ``note`` association proxy."""
if isinstance(val, str):
Expand Down Expand Up @@ -126,7 +158,7 @@ class DagRun(Base, LoggingMixin):
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
conf = Column(PickleType)
_conf = Column("conf", PickleType)
# These two must be either both NULL or both datetime.
data_interval_start = Column(UtcDateTime)
data_interval_end = Column(UtcDateTime)
Expand Down Expand Up @@ -228,7 +260,12 @@ def __init__(
self.execution_date = execution_date
self.start_date = start_date
self.external_trigger = external_trigger
self.conf = conf or {}

if isinstance(conf, str):
self._conf = ConfDict.dump_check(conf)
else:
self._conf = ConfDict(conf or {})

if state is not None:
self.state = state
if queued_at is NOTSET:
Expand Down Expand Up @@ -258,6 +295,16 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
)
return run_id

def get_conf(self):
    """Return the value currently stored in ``_conf``."""
    current = self._conf
    return current

def set_conf(self, value):
    """Wrap ``value`` in a ``ConfDict`` and store the result in ``_conf``."""
    wrapped = ConfDict(value)
    self._conf = wrapped
Comment thread
jscheffl marked this conversation as resolved.

@declared_attr
def conf(self):
    """Declare ``conf`` as a synonym of ``_conf`` accessed through ``get_conf`` / ``set_conf``."""
    accessor = property(self.get_conf, self.set_conf)
    return synonym("_conf", descriptor=accessor)

@property
def stats_tags(self) -> dict[str, str]:
    """Return this run's ``dag_id`` and ``run_type`` tags after passing them through ``prune_dict``."""
    tags = {"dag_id": self.dag_id, "run_type": self.run_type}
    return prune_dict(tags)
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2617,3 +2617,17 @@ def test_dag_run_id_config(session, dag_maker, pattern, run_id, result):
else:
with pytest.raises(AirflowException):
dag_maker.create_dagrun(run_id=run_id)


def test_dagrun_conf():
    """Check that ``DagRun.conf`` keeps a dict and rejects non-JSON items and non-dict values."""
    run = DagRun(conf={"test": 1234})
    assert run.conf == {"test": 1234}

    # Assigning a non-serializable item must be refused.
    with pytest.raises(AirflowException) as item_error:
        run.conf["non_json"] = timezone.utcnow()
    assert str(item_error.value) == "Cannot assign non JSON Serializable value"

    # Replacing the whole conf with something that is not a dict must be refused too.
    not_a_dict = 1
    with pytest.raises(AirflowException) as assign_error:
        run.conf = not_a_dict
    assert str(assign_error.value) == f"Object of type {type(not_a_dict)} must be a dict"