From d4af84ac3b1925a366b801149a2dc3ae929fa205 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Sun, 19 Oct 2025 11:16:06 -0700 Subject: [PATCH] fix delta expecting sys columns in apply steps --- src/datachain/delta.py | 33 +++++++++++++++++++++++++++++++++ tests/func/test_delta.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/src/datachain/delta.py b/src/datachain/delta.py index bb33b1ef4..757c017a7 100644 --- a/src/datachain/delta.py +++ b/src/datachain/delta.py @@ -1,12 +1,16 @@ +import hashlib from collections.abc import Sequence from copy import copy from functools import wraps from typing import TYPE_CHECKING, TypeVar +from attrs import frozen + import datachain from datachain.dataset import DatasetDependency, DatasetRecord from datachain.error import DatasetNotFoundError from datachain.project import Project +from datachain.query.dataset import Step, step_result if TYPE_CHECKING: from collections.abc import Callable @@ -14,7 +18,9 @@ from typing_extensions import ParamSpec + from datachain.catalog import Catalog from datachain.lib.dc import DataChain + from datachain.query.dataset import QueryGenerator P = ParamSpec("P") @@ -43,11 +49,38 @@ def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T: return _inner +@frozen +class _RegenerateSystemColumnsStep(Step): + catalog: "Catalog" + + def hash_inputs(self) -> str: + return hashlib.sha256(b"regenerate_sys_columns").hexdigest() + + def apply(self, query_generator: "QueryGenerator", temp_tables: list[str]): + selectable = query_generator.select() + regenerated = self.catalog.warehouse._regenerate_system_columns( + selectable, + keep_existing_columns=True, + regenerate_columns=None, + ) + + def q(*columns): + return regenerated.with_only_columns(*columns) + + return step_result(q, regenerated.selected_columns) + + def _append_steps(dc: "DataChain", other: "DataChain"): """Returns cloned chain with appended steps from other chain. Steps are all those modification methods applied like filters, mappers etc. """ dc = dc.clone() + dc._query.steps.append( + _RegenerateSystemColumnsStep( + catalog=dc.session.catalog, + ) + ) + dc._query.steps += other._query.steps.copy() dc.signals_schema = other.signals_schema return dc diff --git a/tests/func/test_delta.py b/tests/func/test_delta.py index 5165f94fd..62cd79085 100644 --- a/tests/func/test_delta.py +++ b/tests/func/test_delta.py @@ -1,4 +1,5 @@ import os +import uuid import pytest import regex as re @@ -224,6 +225,42 @@ def test_delta_update_unsafe(test_session): } +def test_delta_replay_regenerates_system_columns(test_session): + source_name = f"regen_source_{uuid.uuid4().hex[:8]}" + result_name = f"regen_result_{uuid.uuid4().hex[:8]}" + + dc.read_values( + measurement_id=[1, 2], + err=["", ""], + num=[1, 2], + session=test_session, + ).save(source_name) + + def build_chain(delta: bool): + read_kwargs = {"session": test_session} + if delta: + read_kwargs.update({"delta": True, "delta_on": "measurement_id"}) + return ( + dc.read_dataset(source_name, **read_kwargs) + .filter(C.err == "") + .select_except("err") + .map(double=lambda num: num * 2, output=int) + .select_except("num") + ) + + build_chain(delta=False).save(result_name) + + build_chain(delta=True).save( + result_name, + delta=True, + delta_on="measurement_id", + ) + + assert set( + dc.read_dataset(result_name, session=test_session).to_values("measurement_id") + ) == {1, 2} + + def test_delta_update_from_storage(test_session, tmp_dir, tmp_path): ds_name = "delta_ds" path = tmp_dir.as_uri()