Migrate pickled data & change XCom value type to JSON (#44166)
follow-up of #43905

Changes:
- Changed `XCom.value` column to JSON for all dbs.
- Archived pickled XCom data to `_xcom_archive` and removed it from the `xcom` table.
- Removed the string encoding/decoding step from XCom serialization and deserialization logic.
- Updated `XComObjectStorageBackend` logic to make it compatible with both Airflow 2 and 3.
kaxil authored Nov 19, 2024
1 parent 39042c8 commit 86c4c6f
Showing 11 changed files with 222 additions and 20 deletions.
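Before applying the new migration below, it can be worth gauging how many XCom rows actually hold pickled data. A minimal sketch, assuming direct SQLAlchemy access to the Airflow metadata database (the connection URI is a placeholder), using the same per-dialect first-byte checks the migration itself uses:

from sqlalchemy import create_engine, text

# Placeholder URI; point this at your Airflow metadata database.
engine = create_engine("postgresql://airflow:airflow@localhost/airflow")

# The same per-dialect conditions the migration uses to spot pickle protocol 2+ payloads.
conditions = {
    "postgresql": "get_byte(value, 0) = 128",
    "mysql": "HEX(SUBSTRING(value, 1, 1)) = '80'",
    "sqlite": "substr(value, 1, 1) = char(128)",
}

with engine.connect() as conn:
    cond = conditions[engine.dialect.name]
    pickled = conn.execute(
        text(f"SELECT COUNT(*) FROM xcom WHERE value IS NOT NULL AND {cond}")
    ).scalar()
    print(f"{pickled} pickled XCom rows would be archived to _xcom_archive")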
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Remove pickled data from xcom table.
Revision ID: eed27faa34e3
Revises: 9fc3fc5de720
Create Date: 2024-11-18 18:41:50.849514
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy import text
from sqlalchemy.dialects.mysql import LONGBLOB

from airflow.migrations.db_types import TIMESTAMP, StringID

revision = "eed27faa34e3"
down_revision = "9fc3fc5de720"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
    """Apply Remove pickled data from xcom table."""
    # Summary of the change:
    # 1. Create an archived table (`_xcom_archive`) to store the current "pickled" data in the xcom table
    # 2. Extract and archive the pickled data using the condition
    # 3. Delete the pickled data from the xcom table so that we can update the column type
    # 4. Update the XCom.value column type to JSON from LargeBinary/LongBlob

    conn = op.get_bind()
    dialect = conn.dialect.name

    # Create an archived table to store the current data
    op.create_table(
        "_xcom_archive",
        sa.Column("dag_run_id", sa.Integer(), nullable=False, primary_key=True),
        sa.Column("task_id", StringID(length=250), nullable=False, primary_key=True),
        sa.Column("map_index", sa.Integer(), nullable=False, server_default=sa.text("-1"), primary_key=True),
        sa.Column("key", StringID(length=512), nullable=False, primary_key=True),
        sa.Column("dag_id", StringID(length=250), nullable=False),
        sa.Column("run_id", StringID(length=250), nullable=False),
        sa.Column("value", sa.LargeBinary().with_variant(LONGBLOB, "mysql"), nullable=True),
        sa.Column("timestamp", TIMESTAMP(), nullable=False),
        sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key"),
        if_not_exists=True,
    )

    # Condition to detect pickled data for different databases
    condition_templates = {
        "postgresql": "get_byte(value, 0) = 128",
        "mysql": "HEX(SUBSTRING(value, 1, 1)) = '80'",
        "sqlite": "substr(value, 1, 1) = char(128)",
    }
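    # Pickle protocol 2+ streams always start with the PROTO opcode byte 0x80 (decimal 128),
    # so checking the first byte of `value` is enough to tell pickled blobs apart from JSON text.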

    condition = condition_templates.get(dialect)
    if not condition:
        raise RuntimeError(f"Unsupported dialect: {dialect}")

    # Key is a reserved keyword in MySQL, so we need to quote it
    quoted_key = conn.dialect.identifier_preparer.quote("key")

    # Archive pickled data using the condition
    conn.execute(
        text(
            f"""
            INSERT INTO _xcom_archive (dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp)
            SELECT dag_run_id, task_id, map_index, {quoted_key}, dag_id, run_id, value, timestamp
            FROM xcom
            WHERE value IS NOT NULL AND {condition}
            """
        )
    )

    # Delete the pickled data from the xcom table so that we can update the column type
    conn.execute(text(f"DELETE FROM xcom WHERE value IS NOT NULL AND {condition}"))

    # Update the value column type to JSON
    if dialect == "postgresql":
        op.execute(
            """
            ALTER TABLE xcom
            ALTER COLUMN value TYPE JSONB
            USING CASE
                WHEN value IS NOT NULL THEN CAST(CONVERT_FROM(value, 'UTF8') AS JSONB)
                ELSE NULL
            END
            """
        )
    elif dialect == "mysql":
        op.add_column("xcom", sa.Column("value_json", sa.JSON(), nullable=True))
        op.execute("UPDATE xcom SET value_json = CAST(value AS CHAR CHARACTER SET utf8mb4)")
        op.drop_column("xcom", "value")
        op.alter_column("xcom", "value_json", existing_type=sa.JSON(), new_column_name="value")
    elif dialect == "sqlite":
        # Rename the existing `value` column to `value_old`
        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.alter_column("value", new_column_name="value_old")

        # Add the new `value` column with JSON type
        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.add_column(sa.Column("value", sa.JSON(), nullable=True))

        # Migrate data from `value_old` to `value`
        conn.execute(
            text(
                """
                UPDATE xcom
                SET value = json(CAST(value_old AS TEXT))
                WHERE value_old IS NOT NULL
                """
            )
        )

        # Drop the old `value_old` column
        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.drop_column("value_old")


def downgrade():
    """Unapply Remove pickled data from xcom table."""
    conn = op.get_bind()
    dialect = conn.dialect.name

    # Revert the value column back to LargeBinary
    if dialect == "postgresql":
        op.execute(
            """
            ALTER TABLE xcom
            ALTER COLUMN value TYPE BYTEA
            USING CASE
                WHEN value IS NOT NULL THEN CONVERT_TO(value::TEXT, 'UTF8')
                ELSE NULL
            END
            """
        )
    elif dialect == "mysql":
        op.add_column("xcom", sa.Column("value_blob", LONGBLOB, nullable=True))
        op.execute("UPDATE xcom SET value_blob = CAST(value AS BINARY);")
        op.drop_column("xcom", "value")
        op.alter_column("xcom", "value_blob", existing_type=LONGBLOB, new_column_name="value")

    elif dialect == "sqlite":
        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.alter_column("value", new_column_name="value_old")

        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.add_column(sa.Column("value", sa.LargeBinary, nullable=True))

        conn.execute(
            text(
                """
                UPDATE xcom
                SET value = CAST(value_old AS BLOB)
                WHERE value_old IS NOT NULL
                """
            )
        )

        with op.batch_alter_table("xcom", schema=None) as batch_op:
            batch_op.drop_column("value_old")
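If any of the archived values are still needed, they can be read back from `_xcom_archive` before the table is dropped. A rough sketch, assuming direct SQLAlchemy access to the metadata database (placeholder URI) and that the archived payloads are trusted, since unpickling untrusted data is unsafe:

import pickle

from sqlalchemy import create_engine, text

# Placeholder URI; point this at your Airflow metadata database.
engine = create_engine("postgresql://airflow:airflow@localhost/airflow")

with engine.connect() as conn:
    rows = conn.execute(
        text("SELECT dag_id, task_id, run_id, map_index, value FROM _xcom_archive")
    )
    for dag_id, task_id, run_id, map_index, value in rows:
        obj = pickle.loads(value)  # only unpickle payloads you trust
        print(dag_id, task_id, run_id, map_index, type(obj).__name__)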
14 changes: 8 additions & 6 deletions airflow/models/xcom.py
@@ -23,18 +23,17 @@
 from typing import TYPE_CHECKING, Any, Iterable, cast

 from sqlalchemy import (
+    JSON,
     Column,
     ForeignKeyConstraint,
     Index,
     Integer,
-    LargeBinary,
     PrimaryKeyConstraint,
     String,
     delete,
     select,
     text,
 )
-from sqlalchemy.dialects.mysql import LONGBLOB
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import Query, reconstructor, relationship

@@ -80,7 +79,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
     dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

-    value = Column(LargeBinary().with_variant(LONGBLOB, "mysql"))
+    value = Column(JSON)
     timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

     __table_args__ = (
@@ -453,9 +452,12 @@ def serialize_value(
         dag_id: str | None = None,
         run_id: str | None = None,
         map_index: int | None = None,
-    ) -> Any:
+    ) -> str:
         """Serialize XCom value to JSON str."""
-        return json.dumps(value, cls=XComEncoder).encode("UTF-8")
+        try:
+            return json.dumps(value, cls=XComEncoder)
+        except (ValueError, TypeError):
+            raise ValueError("XCom value must be JSON serializable")

     @staticmethod
     def _deserialize_value(result: XCom, orm: bool) -> Any:
@@ -466,7 +468,7 @@ def _deserialize_value(result: XCom, orm: bool) -> Any:
         if result.value is None:
             return None

-        return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
+        return json.loads(result.value, cls=XComDecoder, object_hook=object_hook)

     @staticmethod
     def deserialize_value(result: XCom) -> Any:
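With this change, `BaseXCom.serialize_value` returns a JSON string instead of UTF-8 bytes, and anything the encoder cannot turn into JSON now fails with `ValueError` rather than being pickled. A small usage sketch against an Airflow 3 environment (the dict is an arbitrary example value):

import json

from airflow.models.xcom import BaseXCom

s = BaseXCom.serialize_value({"records": 42})
assert isinstance(s, str)                 # a JSON string, no longer bytes
assert json.loads(s) == {"records": 42}   # round-trips through plain JSON

Custom XCom backends that overrode `serialize_value` to return bytes should now return a `str` as well, as the updated tests further down illustrate.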
2 changes: 1 addition & 1 deletion airflow/utils/db.py
@@ -97,7 +97,7 @@ class MappedClassProtocol(Protocol):
     "2.9.2": "686269002441",
     "2.10.0": "22ed7efa9da2",
     "2.10.3": "5f2621c13b39",
-    "3.0.0": "9fc3fc5de720",
+    "3.0.0": "eed27faa34e3",
 }


2 changes: 1 addition & 1 deletion airflow/www/views.py
@@ -3861,7 +3861,7 @@ class XComModelView(AirflowModelView):
         permissions.ACTION_CAN_ACCESS_MENU,
     ]

-    search_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "logical_date"]
+    search_columns = ["key", "timestamp", "dag_id", "task_id", "run_id", "logical_date"]
     list_columns = ["key", "value", "timestamp", "dag_id", "task_id", "run_id", "map_index", "logical_date"]
     base_order = ("dag_run_id", "desc")

2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-028d2fec22a15bbf5794e2fc7522eaf880a8b6293ead484780ef1a14e6cd9b48
+7748eec981f977cc97b852d1fe982aebe24ec2d090ae8493a65cea101f9d42a5
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.svg
4 changes: 3 additions & 1 deletion docs/apache-airflow/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed when you run ``airflow db migrate``
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
| ``9fc3fc5de720`` (head) | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. |
| ``eed27faa34e3`` (head) | ``9fc3fc5de720`` | ``3.0.0`` | Remove pickled data from xcom table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``9fc3fc5de720`` | ``2b47dc6bc8df`` | ``3.0.0`` | Add references between assets and triggers. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``2b47dc6bc8df`` | ``d03e4a635aa3`` | ``3.0.0`` | add dag versioning. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
4 changes: 4 additions & 0 deletions newsfragments/aip-72.significant.rst
@@ -27,3 +27,7 @@ As part of this change the following breaking changes have occurred:

If you still need to use pickling, you can use a custom XCom backend that stores references in the metadata DB and
the pickled data can be stored in a separate storage like S3.

The ``value`` field in the XCom table has been changed to a ``JSON`` type via a DB migration. XCom records that
contain pickled data are archived in the ``_xcom_archive`` table. You can safely drop this table if you don't need
the data anymore.
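For teams that genuinely need pickling, one possible shape for such a backend — purely an illustrative sketch, not part of this commit — is to pickle the payload to object storage via `ObjectStoragePath` and push only the path through `BaseXCom`; the bucket URL below is a placeholder:

import pickle
import uuid

from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom


class PickleOffloadXComBackend(BaseXCom):
    # Placeholder location; any fsspec-compatible URL should work here.
    BASE_PATH = "s3://my-xcom-bucket/pickled"

    @staticmethod
    def serialize_value(value, *, key=None, task_id=None, dag_id=None, run_id=None, map_index=None):
        # Write the pickled payload outside the metadata DB and store only a JSON-friendly reference.
        p = ObjectStoragePath(f"{PickleOffloadXComBackend.BASE_PATH}/{uuid.uuid4()}.pkl")
        p.parent.mkdir(parents=True, exist_ok=True)
        with p.open("wb") as f:
            f.write(pickle.dumps(value))
        return BaseXCom.serialize_value(str(p))

    @staticmethod
    def deserialize_value(result):
        # Resolve the stored path, then load the pickled payload back from object storage.
        path = BaseXCom.deserialize_value(result)
        with ObjectStoragePath(path).open("rb") as f:
            return pickle.loads(f.read())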
22 changes: 17 additions & 5 deletions providers/src/airflow/providers/common/io/xcom/backend.py
@@ -25,7 +25,9 @@
 from urllib.parse import urlsplit

 import fsspec.utils
+from packaging.version import Version

+from airflow import __version__ as airflow_version
 from airflow.configuration import conf
 from airflow.io.path import ObjectStoragePath
 from airflow.models.xcom import BaseXCom
@@ -41,6 +43,10 @@
 SECTION = "common.io"


+AIRFLOW_VERSION = Version(airflow_version)
+AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0")
+
+
 def _get_compression_suffix(compression: str) -> str:
     """
     Return the compression suffix for the given compression.
@@ -103,7 +109,7 @@ def _get_full_path(data: str) -> ObjectStoragePath:
         raise ValueError(f"Not a valid url: {data}")

     @staticmethod
-    def serialize_value(
+    def serialize_value(  # type: ignore[override]
         value: T,
         *,
         key: str | None = None,
@@ -114,16 +120,22 @@ def serialize_value(
     ) -> bytes | str:
         # we will always serialize ourselves and not by BaseXCom as the deserialize method
         # from BaseXCom accepts only XCom objects and not the value directly
-        s_val = json.dumps(value, cls=XComEncoder).encode("utf-8")
+        s_val = json.dumps(value, cls=XComEncoder)
+        s_val_encoded = s_val.encode("utf-8")

         if compression := _get_compression():
             suffix = f".{_get_compression_suffix(compression)}"
         else:
             suffix = ""

         threshold = _get_threshold()
-        if threshold < 0 or len(s_val) < threshold:  # Either no threshold or value is small enough.
-            return s_val
+        if threshold < 0 or len(s_val_encoded) < threshold:  # Either no threshold or value is small enough.
+            if AIRFLOW_V_3_0_PLUS:
+                return s_val
+            else:
+                # TODO: Remove this branch once we drop support for Airflow 2
+                # This is for Airflow 2.10 where the value is expected to be bytes
+                return s_val_encoded

         base_path = _get_base_path()
         while True:  # Safeguard against collisions.
@@ -138,7 +150,7 @@
             p.parent.mkdir(parents=True, exist_ok=True)

             with p.open(mode="wb", compression=compression) as f:
-                f.write(s_val)
+                f.write(s_val_encoded)
             return BaseXCom.serialize_value(str(p))

     @staticmethod
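The compatibility trick above is worth calling out on its own: the provider decides at import time whether it is running on Airflow 3 (store a JSON `str`) or Airflow 2 (store UTF-8 `bytes`). A condensed sketch of that pattern — `to_db_value` is a made-up helper name, not part of the provider:

from __future__ import annotations

from packaging.version import Version

from airflow import __version__ as airflow_version

AIRFLOW_VERSION = Version(airflow_version)
# base_version strips suffixes such as ".dev0" or "rc1" before the comparison.
AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0")


def to_db_value(s_val: str) -> str | bytes:
    # Airflow 3 stores XCom values as JSON text; Airflow 2 expects UTF-8 bytes.
    return s_val if AIRFLOW_V_3_0_PLUS else s_val.encode("utf-8")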
2 changes: 1 addition & 1 deletion tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -640,7 +640,7 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids):
             xcom = XCom(
                 dag_run_id=dagrun.id,
                 key=f"TEST_XCOM_KEY{i}",
-                value=b"null",
+                value="null",
                 run_id=self.run_id,
                 task_id=self.task_id,
                 dag_id=self.dag_id,
6 changes: 3 additions & 3 deletions tests/models/test_xcom.py
@@ -111,14 +111,14 @@ def test_resolve_xcom_class(self):
     def test_resolve_xcom_class_fallback_to_basexcom(self):
         cls = resolve_xcom_backend()
         assert issubclass(cls, BaseXCom)
-        assert cls.serialize_value([1]) == b"[1]"
+        assert cls.serialize_value([1]) == "[1]"

     @conf_vars({("core", "xcom_backend"): "to be removed"})
     def test_resolve_xcom_class_fallback_to_basexcom_no_config(self):
         conf.remove_option("core", "xcom_backend")
         cls = resolve_xcom_backend()
         assert issubclass(cls, BaseXCom)
-        assert cls.serialize_value([1]) == b"[1]"
+        assert cls.serialize_value([1]) == "[1]"

     @mock.patch("airflow.models.xcom.XCom.orm_deserialize_value")
     def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize):
@@ -182,7 +182,7 @@ def serialize_value(
                     run_id=run_id,
                     map_index=map_index,
                 )
-                return json.dumps(value).encode("utf-8")
+                return json.dumps(value)

         get_import.return_value = CurrentSignatureXCom

