diff --git a/docs/guide/delta.md b/docs/guide/delta.md
index 751331293..f547a9400 100644
--- a/docs/guide/delta.md
+++ b/docs/guide/delta.md
@@ -80,3 +80,23 @@ Delta processing can be combined with [retry processing](./retry.md) to create a
 
 1. Processes only new or changed records (delta)
 2. Reprocesses records with errors or that are missing (retry)
+
+## Using Delta with Restricted Methods
+
+By default, delta updates cannot be combined with the following methods:
+
+1. `merge`
+2. `union`
+3. `distinct`
+4. `agg`
+5. `group_by`
+
+These methods are restricted because they may produce **unexpected results** when used with delta processing. Delta runs the chain only on a subset of rows (new and changed records), while methods like `distinct`, `agg`, and `group_by` are designed to operate on the entire dataset.
+
+Similarly, `merge` or `union` with a static dataset may produce duplicated rows: because delta appends its output to the previous dataset version, rows contributed by the static dataset can be added again on every run.
+
+If you still need to use these methods together with delta, you can override the restriction by passing an additional flag to `read_dataset()` or `read_storage()`:
+
+```python
+delta_unsafe=True
+```
diff --git a/src/datachain/delta.py b/src/datachain/delta.py
index 1c3792abe..1567cf141 100644
--- a/src/datachain/delta.py
+++ b/src/datachain/delta.py
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import datachain
-from datachain.dataset import DatasetDependency
+from datachain.dataset import DatasetDependency, DatasetRecord
 from datachain.error import DatasetNotFoundError
 from datachain.project import Project
 
@@ -30,9 +30,10 @@ def delta_disabled(
 
     @wraps(method)
     def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
-        if self.delta:
+        if self.delta and not self._delta_unsafe:
             raise NotImplementedError(
-                f"Delta update cannot be used with {method.__name__}"
+                f"Cannot use {method.__name__} with delta datasets - may cause"
+                " inconsistency. Use delta_unsafe flag to allow this operation."
             )
         return method(self, *args, **kwargs)
 
@@ -128,6 +129,7 @@ def _get_retry_chain(
 
 
 def _get_source_info(
+    source_ds: DatasetRecord,
     name: str,
     namespace_name: str,
     project_name: str,
@@ -154,25 +156,23 @@ def _get_source_info(
         indirect=False,
     )
 
-    dep = dependencies[0]
-    if not dep:
+    source_ds_dep = next((d for d in dependencies if d.name == source_ds.name), None)
+    if not source_ds_dep:
         # Starting dataset was removed, back off to normal dataset creation
         return None, None, None, None, None
 
-    source_ds_project = catalog.metastore.get_project(dep.project, dep.namespace)
-    source_ds_name = dep.name
-    source_ds_version = dep.version
-    source_ds_latest_version = catalog.get_dataset(
-        source_ds_name,
-        namespace_name=source_ds_project.namespace.name,
-        project_name=source_ds_project.name,
-    ).latest_version
+    # Refresh the starting dataset to pick up any newly created versions
+    source_ds = catalog.get_dataset(
+        source_ds.name,
+        namespace_name=source_ds.project.namespace.name,
+        project_name=source_ds.project.name,
+    )
 
     return (
-        source_ds_name,
-        source_ds_project,
-        source_ds_version,
-        source_ds_latest_version,
+        source_ds.name,
+        source_ds.project,
+        source_ds_dep.version,
+        source_ds.latest_version,
         dependencies,
     )
 
@@ -244,7 +244,14 @@ def delta_retry_update(
         source_ds_version,
         source_ds_latest_version,
         dependencies,
-    ) = _get_source_info(name, namespace_name, project_name, latest_version, catalog)
+    ) = _get_source_info(
+        dc._query.starting_step.dataset,  # type: ignore[union-attr]
+        name,
+        namespace_name,
+        project_name,
+        latest_version,
+        catalog,
+    )
 
     # If source_ds_name is None, starting dataset was removed
     if source_ds_name is None:
@@ -267,8 +274,9 @@ def delta_retry_update(
     if dependencies:
         dependencies = copy(dependencies)
         dependencies = [d for d in dependencies if d is not None]
+        source_ds_dep = next(d for d in dependencies if d.name == source_ds_name)
         # Update to latest version
-        dependencies[0].version = source_ds_latest_version  # type: ignore[union-attr]
+        source_ds_dep.version = source_ds_latest_version  # type: ignore[union-attr]
 
     # Handle retry functionality if enabled
     if delta_retry:
diff --git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py
index 8ef1cc416..2f3b1aae8 100644
--- a/src/datachain/lib/dc/datachain.py
+++ b/src/datachain/lib/dc/datachain.py
@@ -193,6 +193,7 @@ def __init__(
         self._setup: dict = setup or {}
         self._sys = _sys
         self._delta = False
+        self._delta_unsafe = False
         self._delta_on: Optional[Union[str, Sequence[str]]] = None
         self._delta_result_on: Optional[Union[str, Sequence[str]]] = None
         self._delta_compare: Optional[Union[str, Sequence[str]]] = None
@@ -216,6 +217,7 @@ def _as_delta(
         right_on: Optional[Union[str, Sequence[str]]] = None,
         compare: Optional[Union[str, Sequence[str]]] = None,
         delta_retry: Optional[Union[bool, str]] = None,
+        delta_unsafe: bool = False,
     ) -> "Self":
         """Marks this chain as delta, which means special delta process will be
         called on saving dataset for optimization"""
@@ -226,6 +228,7 @@ def _as_delta(
         self._delta_result_on = right_on
         self._delta_compare = compare
         self._delta_retry = delta_retry
+        self._delta_unsafe = delta_unsafe
         return self
 
     @property
@@ -238,6 +241,10 @@ def delta(self) -> bool:
         """Returns True if this chain is ran in "delta" update mode"""
         return self._delta
 
+    @property
+    def delta_unsafe(self) -> bool:
+        return self._delta_unsafe
+
     @property
     def schema(self) -> dict[str, DataType]:
         """Get schema of the chain."""
@@ -328,6 +335,7 @@ def _evolve(
             right_on=self._delta_result_on,
             compare=self._delta_compare,
             delta_retry=self._delta_retry,
+            delta_unsafe=self._delta_unsafe,
         )
 
         return chain
diff --git a/src/datachain/lib/dc/datasets.py b/src/datachain/lib/dc/datasets.py
index 21c34d88c..f352d0d3e 100644
--- a/src/datachain/lib/dc/datasets.py
+++ b/src/datachain/lib/dc/datasets.py
@@ -40,6 +40,7 @@ def read_dataset(
     delta_result_on: Optional[Union[str, Sequence[str]]] = None,
     delta_compare: Optional[Union[str, Sequence[str]]] = None,
     delta_retry: Optional[Union[bool, str]] = None,
+    delta_unsafe: bool = False,
     update: bool = False,
 ) -> "DataChain":
     """Get data from a saved Dataset. It returns the chain itself.
@@ -80,6 +81,8 @@ def read_dataset(
         update: If True always checks for newer versions available on Studio, even if
             some version of the dataset exists locally already. If False (default), it
             will only fetch the dataset from Studio if it is not found locally.
+        delta_unsafe: Allow restricted ops in delta: merge, agg, union, group_by,
+            distinct.
 
     Example:
@@ -205,6 +208,7 @@ def read_dataset(
             right_on=delta_result_on,
             compare=delta_compare,
             delta_retry=delta_retry,
+            delta_unsafe=delta_unsafe,
         )
 
     return chain
diff --git a/src/datachain/lib/dc/storage.py b/src/datachain/lib/dc/storage.py
index 517e5bce4..a54b3e557 100644
--- a/src/datachain/lib/dc/storage.py
+++ b/src/datachain/lib/dc/storage.py
@@ -43,6 +43,7 @@ def read_storage(
     delta_result_on: Optional[Union[str, Sequence[str]]] = None,
     delta_compare: Optional[Union[str, Sequence[str]]] = None,
     delta_retry: Optional[Union[bool, str]] = None,
+    delta_unsafe: bool = False,
     client_config: Optional[dict] = None,
 ) -> "DataChain":
     """Get data from storage(s) as a list of file with all file attributes.
@@ -77,6 +78,9 @@ def read_storage(
               (error mode)
             - True: Reprocess records missing from the result dataset (missing mode)
            - None: No retry processing (default)
+        delta_unsafe: Allow restricted ops in delta: merge, agg, union, group_by,
+            distinct. Caller must ensure datasets are consistent and not partially
+            updated.
 
     Returns:
         DataChain: A DataChain object containing the file information.
@@ -218,6 +222,7 @@ def lst_fn(ds_name, lst_uri):
             right_on=delta_result_on,
             compare=delta_compare,
             delta_retry=delta_retry,
+            delta_unsafe=delta_unsafe,
         )
 
     return storage_chain
diff --git a/tests/func/test_delta.py b/tests/func/test_delta.py
index 6025ff1ed..b34caf7c9 100644
--- a/tests/func/test_delta.py
+++ b/tests/func/test_delta.py
@@ -14,26 +14,16 @@ def _get_dependencies(catalog, name, version) -> list[tuple[str, str]]:
     return sorted(
         [
-            (f"{d.namespace}.{d.project}.{d.name}", d.version)
+            (d.name, d.version)
             for d in catalog.get_dataset_dependencies(name, version, indirect=False)
         ]
     )
 
 
-@pytest.mark.parametrize("project", ("global.dev", ""))
-def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path, project):
+def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path):
     catalog = test_session.catalog
-    default_namespace_name = catalog.metastore.default_namespace_name
-    default_project_name = catalog.metastore.default_project_name
-
-    if project:
-        starting_ds_name = f"{project}.starting_ds"
-        dependency_ds_name = starting_ds_name
-    else:
-        starting_ds_name = "starting_ds"
-        dependency_ds_name = (
-            f"{default_namespace_name}.{default_project_name}.{starting_ds_name}"
-        )
+
+    starting_ds_name = "starting_ds"
     ds_name = "delta_ds"
 
     images = [
@@ -66,16 +56,12 @@ def create_delta_dataset(ds_name):
     create_image_dataset(starting_ds_name, images[:2])
     # first version of delta dataset
     create_delta_dataset(ds_name)
-    assert _get_dependencies(catalog, ds_name, "1.0.0") == [
-        (dependency_ds_name, "1.0.0")
-    ]
+    assert _get_dependencies(catalog, ds_name, "1.0.0") == [(starting_ds_name, "1.0.0")]
 
     # second version of starting dataset
     create_image_dataset(starting_ds_name, images[2:])
     # second version of delta dataset
     create_delta_dataset(ds_name)
-    assert _get_dependencies(catalog, ds_name, "1.0.1") == [
-        (dependency_ds_name, "1.0.1")
-    ]
+    assert _get_dependencies(catalog, ds_name, "1.0.1") == [(starting_ds_name, "1.0.1")]
 
     assert (dc.read_dataset(ds_name, version="1.0.0").order_by("file.path")).to_values(
         "file.path"
@@ -96,6 +82,66 @@ def create_delta_dataset(ds_name):
     create_delta_dataset(ds_name)
 
 
+def test_delta_update_unsafe(test_session):
+    catalog = test_session.catalog
+
+    starting_ds_name = "starting_ds"
+    merge_ds_name = "merge_ds"
+    ds_name = "delta_ds"
+
+    # create dataset which will be merged to delta one
+    merge_ds = dc.read_values(
+        id=[1, 2, 3, 4, 5, 6], value=[1, 2, 3, 4, 5, 6], session=test_session
+    ).save(merge_ds_name)
+
+    # first version of starting dataset
+    dc.read_values(id=[1, 2, 3], session=test_session).save(starting_ds_name)
+    # first version of delta dataset
+    dc.read_dataset(
+        starting_ds_name,
+        session=test_session,
+        delta_on="id",
+        delta=True,
+        delta_unsafe=True,
+    ).merge(merge_ds, on="id", inner=True).save(ds_name)
+
+    assert set(_get_dependencies(catalog, ds_name, "1.0.0")) == {
+        (starting_ds_name, "1.0.0"),
+        (merge_ds_name, "1.0.0"),
+    }
+
+    # second version of starting dataset
+    dc.read_values(id=[1, 2, 3, 4, 5, 6], session=test_session).save(starting_ds_name)
+    # second version of delta dataset
+    dc.read_dataset(
+        starting_ds_name,
+        session=test_session,
+        delta_on="id",
+        delta=True,
+        delta_unsafe=True,
+    ).merge(merge_ds, on="id", inner=True).save(ds_name)
+
+    assert set(_get_dependencies(catalog, ds_name, "1.0.1")) == {
+        (starting_ds_name, "1.0.1"),
+        (merge_ds_name, "1.0.0"),
+    }
+
+    assert set((dc.read_dataset(ds_name, version="1.0.0")).to_list("id", "value")) == {
+        (1, 1),
+        (2, 2),
+        (3, 3),
+    }
+
+    assert set((dc.read_dataset(ds_name, version="1.0.1")).to_list("id", "value")) == {
+        (1, 1),
+        (2, 2),
+        (3, 3),
+        (4, 4),
+        (5, 5),
+        (6, 6),
+    }
+
+
 def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
     ds_name = "delta_ds"
     path = tmp_dir.as_uri()
@@ -249,8 +295,6 @@ def get_index(file: File) -> int:
 
 def test_delta_update_no_diff(test_session, tmp_dir, tmp_path):
     catalog = test_session.catalog
-    default_namespace_name = catalog.metastore.default_namespace_name
-    default_project_name = catalog.metastore.default_project_name
     ds_name = "delta_ds"
     path = tmp_dir.as_uri()
     tmp_dir = tmp_dir / "images"
@@ -301,7 +345,8 @@ def get_index(file: File) -> int:
 
     assert str(exc_info.value) == (
         f"Dataset {ds_name} version 1.0.1 not found in namespace "
-        f"{default_namespace_name} and project {default_project_name}"
+        f"{catalog.metastore.default_namespace_name}"
+        f" and project {catalog.metastore.default_project_name}"
     )
 
 
@@ -325,11 +370,13 @@ def test_delta_update_union(test_session, file_dataset):
                file_dataset.name,
                session=test_session,
                delta=True,
-                delta_on=["file.source", "file.path"],
            ).union(dc.read_dataset("numbers"), session=test_session)
        )
 
-    assert str(excinfo.value) == "Delta update cannot be used with union"
+    assert str(excinfo.value) == (
+        "Cannot use union with delta datasets - may cause inconsistency."
+        " Use delta_unsafe flag to allow this operation."
+    )
 
 
 def test_delta_update_merge(test_session, file_dataset):
@@ -341,11 +388,13 @@
                file_dataset.name,
                session=test_session,
                delta=True,
-                delta_on=["file.source", "file.path"],
            ).merge(dc.read_dataset("numbers"), on="id", session=test_session)
        )
 
-    assert str(excinfo.value) == "Delta update cannot be used with merge"
+    assert str(excinfo.value) == (
+        "Cannot use merge with delta datasets - may cause inconsistency."
+        " Use delta_unsafe flag to allow this operation."
+    )
 
 
 def test_delta_update_distinct(test_session, file_dataset):
@@ -355,11 +404,13 @@
                file_dataset.name,
                session=test_session,
                delta=True,
-                delta_on=["file.source", "file.path"],
            ).distinct("file.path")
        )
 
-    assert str(excinfo.value) == "Delta update cannot be used with distinct"
+    assert str(excinfo.value) == (
+        "Cannot use distinct with delta datasets - may cause inconsistency."
+        " Use delta_unsafe flag to allow this operation."
+    )
 
 
 def test_delta_update_group_by(test_session, file_dataset):
@@ -369,11 +420,13 @@
                file_dataset.name,
                session=test_session,
                delta=True,
-                delta_on=["file.source", "file.path"],
            ).group_by(cnt=func.count(), partition_by="file.path")
        )
 
-    assert str(excinfo.value) == "Delta update cannot be used with group_by"
+    assert str(excinfo.value) == (
+        "Cannot use group_by with delta datasets - may cause inconsistency."
+        " Use delta_unsafe flag to allow this operation."
+    )
 
 
 def test_delta_update_agg(test_session, file_dataset):
@@ -383,8 +436,10 @@
                file_dataset.name,
                session=test_session,
                delta=True,
-                delta_on=["file.source", "file.path"],
            ).agg(cnt=func.count(), partition_by="file.path")
        )
 
-    assert str(excinfo.value) == "Delta update cannot be used with agg"
+    assert str(excinfo.value) == (
+        "Cannot use agg with delta datasets - may cause inconsistency."
+        " Use delta_unsafe flag to allow this operation."
+    )
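For reference, a minimal sketch of the usage pattern this change enables, mirroring the new `test_delta_update_unsafe` test above; the dataset names (`events`, `static_ds`, `events_enriched`) are hypothetical:

```python
import datachain as dc

# Static lookup dataset that every delta run merges against.
static = dc.read_values(id=[1, 2, 3], value=[10, 20, 30]).save("static_ds")

# Delta chain: only new or changed rows of "events" (matched on "id") are processed.
# merge() is normally rejected in delta mode; delta_unsafe=True overrides that check.
(
    dc.read_dataset("events", delta=True, delta_on="id", delta_unsafe=True)
    .merge(static, on="id", inner=True)
    .save("events_enriched")
)
```

As the new test asserts, re-running the chain after the starting dataset grows appends only rows for the newly added ids, and the dependency on the merged dataset stays at its original version.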