Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions cosmos/operators/_watcher/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import base64
import json
import re
import zlib
from typing import Any

Expand All @@ -20,16 +21,37 @@

# First component of every Variable key built by ``_xcom_backup_variable_key``.
XCOM_BACKUP_VARIABLE_PREFIX = "cosmos_xcom_backup"

# Characters that secrets backends commonly reject in Variable keys. AWS
# Secrets Manager allows alphanumerics plus ``-/_+=.@!``; GCP Secret Manager
# is stricter at ``[A-Za-z0-9_-]``. Run IDs routinely contain ``:`` and ``+``
# (timestamps and timezone offsets, e.g. ``scheduled__2026-05-04T10:15:00+00:00``)
# and dag/task-group IDs can contain ``.``. Sanitize all components down to
# ``[A-Za-z0-9_-]`` so the resulting key is portable across backends and does
# not log a ``ValidationException`` on every ``Variable.set`` call when an
# external secrets backend is configured (Airflow walks the backend chain on
# set as well as get).
#
# The pattern matches any single character outside ``[A-Za-z0-9_-]``;
# ``_sanitize_key_component`` replaces each match with ``_``. NOTE: that
# substitution is lossy -- values that differ only in disallowed characters
# (e.g. ``a:b`` and ``a+b``) sanitize to the same text.
_DISALLOWED_VARIABLE_KEY_CHAR_RE = re.compile(r"[^A-Za-z0-9_-]")


def _xcom_backup_variable_key(dag_id: str, task_group_id: str | None, run_id: str) -> str:
    """Build the Airflow Variable key for the XCom backup of a watcher producer run.

    Each component is processed in two steps:

    1. Periods are expanded to a component-specific number of underscores
       (3 for ``dag_id``, 2 for ``task_group_id``, 1 for ``run_id``) so keys
       remain visually parseable.
    2. Any remaining character outside ``[A-Za-z0-9_-]`` is normalized to
       ``_`` via ``_sanitize_key_component``.

    The prefix and the sanitized components are then joined with ``__``.
    ``task_group_id`` is omitted from the key when it is ``None`` or empty.

    Because sanitization is lossy, inputs that differ only in replaced
    characters (for example ``:`` versus ``+``) produce the same key.
    """
    # Every component goes through the sanitizer exactly once; appending the
    # raw ``.replace(...)`` result as well would duplicate the component and
    # leave characters such as ``:`` and ``+`` in the key.
    parts = [XCOM_BACKUP_VARIABLE_PREFIX, _sanitize_key_component(dag_id.replace(".", "___"))]
    if task_group_id:
        parts.append(_sanitize_key_component(task_group_id.replace(".", "__")))
    parts.append(_sanitize_key_component(run_id.replace(".", "_")))
    return "__".join(parts)


def _sanitize_key_component(value: str) -> str:
    """Return ``value`` with each character that secrets backends reject swapped for ``_``."""
    sanitized = _DISALLOWED_VARIABLE_KEY_CHAR_RE.sub("_", value)
    return sanitized


def _get_task_group_id(ti: Any) -> str | None:
"""Extract the task_group_id from a task instance, if available."""
task = getattr(ti, "task", None)
Expand Down
35 changes: 35 additions & 0 deletions tests/operators/_watcher/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_init_xcom_backup,
_persist_backup,
_restore_xcom_from_variable,
_xcom_backup_variable_key,
)


Expand Down Expand Up @@ -124,6 +125,40 @@ def xcom_push(self, key, value, **_):
self.store[key] = value


class TestXcomBackupVariableKey:
    """Checks that ``_xcom_backup_variable_key`` emits secrets-backend-safe keys.

    Run IDs carry characters such as ``:`` and ``+`` that stricter secrets
    backends refuse in Variable keys, so every generated key is expected to
    stay inside the alphabet held in ``AWS_ALLOWED``.
    """

    # Despite the name, this is the strict portable alphabet asserted below:
    # ASCII letters, digits, ``-`` and ``_``.
    AWS_ALLOWED = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"

    def test_period_replacement_preserved(self):
        # Periods become 3 / 2 / 1 underscores for dag, task group and run id.
        expected = "cosmos_xcom_backup__a___b__g__h__r_s"
        assert _xcom_backup_variable_key("a.b", "g.h", "r.s") == expected

    def test_run_id_with_colon_and_plus_is_sanitized(self):
        variable_key = _xcom_backup_variable_key(
            dag_id="dbt_daily",
            task_group_id=None,
            run_id="scheduled__2026-05-04T10:15:00+00:00",
        )
        assert variable_key == "cosmos_xcom_backup__dbt_daily__scheduled__2026-05-04T10_15_00_00_00"
        assert ":" not in variable_key
        assert "+" not in variable_key

    @pytest.mark.parametrize(
        "dag_id,task_group_id,run_id",
        [
            ("dbt_daily", None, "scheduled__2026-05-04T10:15:00+00:00"),
            ("dag.with.dots", "group.id", "manual__2026-01-01T00:00:00+00:00"),
            ("dag with spaces", None, "manual__2026-01-01"),
            ("dag(parens)", None, "run/with/slashes"),
            ("dag*star", "group:colon", "run+plus"),
        ],
    )
    def test_result_contains_only_safe_characters(self, dag_id, task_group_id, run_id):
        variable_key = _xcom_backup_variable_key(dag_id, task_group_id, run_id)
        # Collect offenders instead of short-circuiting; the key itself is
        # still the assertion message, as before.
        unsafe = [char for char in variable_key if char not in self.AWS_ALLOWED]
        assert not unsafe, variable_key


class TestInitXcomBackup:
def test_sets_var_key_and_buffer_on_ti(self):
ti = _MockTI()
Expand Down
Loading