diff --git a/src/datachain/delta.py b/src/datachain/delta.py index 1567cf141..9704cdb58 100644 --- a/src/datachain/delta.py +++ b/src/datachain/delta.py @@ -125,7 +125,15 @@ def _get_retry_chain( # Subtract also diff chain since some items might be picked # up by `delta=True` itself (e.g. records got modified AND are missing in the # result dataset atm) - return retry_chain.subtract(diff_chain, on=on) if retry_chain else None + on = [on] if isinstance(on, str) else on + + return ( + retry_chain.diff( + diff_chain, on=on, added=True, same=True, modified=False, deleted=False + ).distinct(*on) + if retry_chain + else None + ) def _get_source_info( diff --git a/tests/func/test_retry.py b/tests/func/test_retry.py index e8f025610..1178c150b 100644 --- a/tests/func/test_retry.py +++ b/tests/func/test_retry.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING @@ -425,3 +426,42 @@ def test_delta_and_delta_retry_no_duplicates(test_session): assert len(ids_in_result) == 4 assert len(set(ids_in_result)) == 4 # No duplicate IDs assert set(ids_in_result) == {1, 2, 3, 4} + + +def test_repeating_errors(test_session): + def run_delta(): + def func(id) -> Iterator[tuple[int, str, str]]: + yield id, "name1", "error" + yield id, "name2", "error" + + return ( + dc.read_dataset( + "sample_data", + delta=True, + delta_on="id", + delta_result_on="id", + delta_retry="error", + session=test_session, + ) + .gen(func, output={"id": int, "name": str, "error": str}) + .save("processed_data") + ) + return dc.read_dataset("processed_data") + + _create_sample_data( + test_session, ids=list(range(1)), contents=[str(i) for i in range(1)] + ) + ch1 = run_delta() + assert sorted(ch1.collect("id")) == [0, 0] + + _create_sample_data( + test_session, ids=list(range(2)), contents=[str(i) for i in range(2)] + ) + ch2 = run_delta() + assert sorted(ch2.collect("id")) == [0, 0, 1, 1] + + _create_sample_data( + test_session, ids=list(range(3)), contents=[str(i) for i in range(3)] + ) + ch3 = run_delta() + assert sorted(ch3.collect("id")) == [0, 0, 1, 1, 2, 2] diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index e958b675c..80e696cbd 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2230,6 +2230,13 @@ def test_subtract(test_session): assert set(chain4.subtract(chain5, on="d", right_on="a").to_list()) == {(3, "z")} +def test_subtract_duplicated_rows(test_session): + chain1 = dc.read_values(id=[1, 1], name=["1", "1"], session=test_session) + chain2 = dc.read_values(id=[2], name=["2"], session=test_session) + sub = chain1.subtract(chain2, on="id") + assert set(sub.to_list()) == {(1, "1"), (1, "1")} + + def test_subtract_error(test_session): chain1 = dc.read_values(a=[1, 1, 2], b=["x", "y", "z"], session=test_session) chain2 = dc.read_values(a=[1, 2], b=["x", "y"], session=test_session)