Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added toolkit compare #719

Merged
merged 10 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 116 additions & 12 deletions src/datachain/lib/diff.py → src/datachain/diff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import string
from collections.abc import Sequence
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import sqlalchemy as sa
Expand All @@ -16,7 +17,22 @@
C = Column


def compare( # noqa: PLR0912, PLR0915, C901
def get_status_col_name() -> str:
"""Returns new unique status col name"""
return "diff_" + "".join(
random.choice(string.ascii_letters) # noqa: S311
for _ in range(10)
)


class CompareStatus(str, Enum):
ADDED = "A"
DELETED = "D"
MODIFIED = "M"
SAME = "S"


def _compare( # noqa: PLR0912, PLR0915, C901
left: "DataChain",
right: "DataChain",
on: Union[str, Sequence[str]],
Expand Down Expand Up @@ -72,13 +88,10 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
"At least one of added, deleted, modified, same flags must be set"
)

# we still need status column for internal implementation even if not
# needed in output
need_status_col = bool(status_col)
status_col = status_col or "diff_" + "".join(
random.choice(string.ascii_letters) # noqa: S311
for _ in range(10)
)
# we still need status column for internal implementation even if not
# needed in the output
status_col = status_col or get_status_col_name()

# calculate on and compare column names
right_on = right_on or on
Expand Down Expand Up @@ -112,25 +125,27 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
]
)
diff_cond.append((added_cond, "A"))
diff_cond.append((added_cond, CompareStatus.ADDED))
if modified and compare:
modified_cond = sa.or_(
*[
C(c) != C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((modified_cond, "M"))
diff_cond.append((modified_cond, CompareStatus.MODIFIED))
if same and compare:
same_cond = sa.and_(
*[
C(c) == C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((same_cond, "S"))
diff_cond.append((same_cond, CompareStatus.SAME))

diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
status_col
)
diff.type = String()

left_right_merge = left.merge(
Expand All @@ -145,7 +160,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
)
)

diff_col = sa.literal("D").label(status_col)
diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
diff_col.type = String()

right_left_merge = right.merge(
Expand Down Expand Up @@ -195,3 +210,92 @@ def _default_val(chain: "DataChain", col: str):
res = res.select_except(C(status_col))

return left._evolve(query=res, signal_schema=schema)


def compare_and_split(
left: "DataChain",
right: "DataChain",
on: Union[str, Sequence[str]],
right_on: Optional[Union[str, Sequence[str]]] = None,
compare: Optional[Union[str, Sequence[str]]] = None,
right_compare: Optional[Union[str, Sequence[str]]] = None,
added: bool = True,
deleted: bool = True,
modified: bool = True,
same: bool = False,
) -> dict[str, "DataChain"]:
"""Comparing two chains and returning multiple chains, one for each of `added`,
`deleted`, `modified` and `same` status. Result is returned in form of
dictionary where each item represents one of the statuses and key values
are `A`, `D`, `M`, `S` corresponding. Note that status column is not in the
resulting chains.
Parameters:
left: Chain to calculate diff on.
right: Chain to calculate diff from.
on: Column or list of columns to match on. If both chains have the
same columns then this column is enough for the match. Otherwise,
`right_on` parameter has to specify the columns for the other chain.
This value is used to find corresponding row in other dataset. If not
found there, row is considered as added (or removed if vice versa), and
if found then row can be either modified or same.
right_on: Optional column or list of columns
for the `other` to match.
compare: Column or list of columns to compare on. If both chains have
the same columns then this column is enough for the compare. Otherwise,
`right_compare` parameter has to specify the columns for the other
chain. This value is used to see if row is modified or same. If
not set, all columns will be used for comparison
right_compare: Optional column or list of columns
for the `other` to compare to.
added (bool): Whether to return chain containing only added rows.
deleted (bool): Whether to return chain containing only deleted rows.
modified (bool): Whether to return chain containing only modified rows.
same (bool): Whether to return chain containing only same rows.
Example:
```py
chains = compare(
persons,
new_persons,
on=["id"],
right_on=["other_id"],
compare=["name"],
added=True,
deleted=True,
modified=True,
same=True,
)
```
"""
status_col = get_status_col_name()

res = _compare(
left,
right,
on,
right_on=right_on,
compare=compare,
right_compare=right_compare,
added=added,
deleted=deleted,
modified=modified,
same=same,
status_col=status_col,
)

chains = {}

def filter_by_status(compare_status) -> "DataChain":
return res.filter(C(status_col) == compare_status).select_except(status_col)

if added:
chains[CompareStatus.ADDED.value] = filter_by_status(CompareStatus.ADDED)
if deleted:
chains[CompareStatus.DELETED.value] = filter_by_status(CompareStatus.DELETED)
if modified:
chains[CompareStatus.MODIFIED.value] = filter_by_status(CompareStatus.MODIFIED)
if same:
chains[CompareStatus.SAME.value] = filter_by_status(CompareStatus.SAME)

return chains
6 changes: 3 additions & 3 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ def compare(

Example:
```py
diff = persons.diff(
res = persons.compare(
new_persons,
on=["id"],
right_on=["other_id"],
Expand All @@ -1549,9 +1549,9 @@ def compare(
)
```
"""
from datachain.lib.diff import compare as chain_compare
from datachain.diff import _compare

return chain_compare(
return _compare(
self,
other,
on,
Expand Down
92 changes: 49 additions & 43 deletions tests/unit/lib/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from pydantic import BaseModel

from datachain.diff import CompareStatus
from datachain.lib.dc import DataChain
from datachain.lib.file import File
from datachain.sql.types import Int64, String
Expand Down Expand Up @@ -58,13 +59,13 @@ def test_compare(test_session, added, deleted, modified, same, status_col, save)

expected = []
if modified:
expected.append(("M", 1, "John1"))
expected.append((CompareStatus.MODIFIED, 1, "John1"))
if added:
expected.append(("A", 2, "Doe"))
expected.append((CompareStatus.ADDED, 2, "Doe"))
if deleted:
expected.append(("D", 3, "Mark"))
expected.append((CompareStatus.DELETED, 3, "Mark"))
if same:
expected.append(("S", 4, "Andy"))
expected.append((CompareStatus.SAME, 4, "Andy"))

collect_fields = ["diff", "id", "name"]
if not status_col:
Expand Down Expand Up @@ -94,10 +95,10 @@ def test_compare_with_from_dataset(test_session):
diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff")

assert list(diff.order_by("id").collect("diff", "id", "name")) == [
("M", 1, "John1"),
("A", 2, "Doe"),
("D", 3, "Mark"),
("S", 4, "Andy"),
(CompareStatus.MODIFIED, 1, "John1"),
(CompareStatus.ADDED, 2, "Doe"),
(CompareStatus.DELETED, 3, "Mark"),
(CompareStatus.SAME, 4, "Andy"),
]


Expand Down Expand Up @@ -144,20 +145,20 @@ def test_compare_with_explicit_compare_fields(

expected = []
if modified:
expected.append(("M", 1, "John1", "New York"))
expected.append((CompareStatus.MODIFIED, 1, "John1", "New York"))
if added:
expected.append(("A", 2, "Doe", "Boston"))
expected.append((CompareStatus.ADDED, 2, "Doe", "Boston"))
if deleted:
expected.append(
(
"D",
CompareStatus.DELETED,
3,
string_default if right_name == "other_name" else "Mark",
"Seattle",
)
)
if same:
expected.append(("S", 4, "Andy", "San Francisco"))
expected.append((CompareStatus.SAME, 4, "Andy", "San Francisco"))

collect_fields = ["diff", "id", "name", "city"]
assert list(diff.order_by("id").collect(*collect_fields)) == expected
Expand Down Expand Up @@ -200,13 +201,13 @@ def test_compare_different_left_right_on_columns(

expected = []
if same:
expected.append(("S", 4, "Andy"))
expected.append((CompareStatus.SAME, 4, "Andy"))
if added:
expected.append(("A", 2, "Doe"))
expected.append((CompareStatus.ADDED, 2, "Doe"))
if modified:
expected.append(("M", 1, "John1"))
expected.append((CompareStatus.MODIFIED, 1, "John1"))
if deleted:
expected.append(("D", int_default, "Mark"))
expected.append((CompareStatus.DELETED, int_default, "Mark"))

collect_fields = ["diff", "id", "name"]
assert list(diff.order_by("name").collect(*collect_fields)) == expected
Expand Down Expand Up @@ -252,9 +253,9 @@ def test_compare_on_equal_datasets(
expected = []
else:
expected = [
("S", 1, "John"),
("S", 2, "Doe"),
("S", 3, "Andy"),
(CompareStatus.SAME, 1, "John"),
(CompareStatus.SAME, 2, "Doe"),
(CompareStatus.SAME, 3, "Andy"),
]

collect_fields = ["diff", "id", "name"]
Expand All @@ -279,10 +280,10 @@ def test_compare_multiple_columns(test_session):

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
[
{"diff": "M", "id": 1, "name": "John", "city": "London"},
{"diff": "A", "id": 2, "name": "Doe", "city": "New York"},
{"diff": "D", "id": 3, "name": "Mark", "city": "Berlin"},
{"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"},
{"diff": CompareStatus.MODIFIED, "id": 1, "name": "John", "city": "London"},
{"diff": CompareStatus.ADDED, "id": 2, "name": "Doe", "city": "New York"},
{"diff": CompareStatus.DELETED, "id": 3, "name": "Mark", "city": "Berlin"},
{"diff": CompareStatus.SAME, "id": 4, "name": "Andy", "city": "Tokyo"},
],
"id",
)
Expand All @@ -306,10 +307,10 @@ def test_compare_multiple_match_columns(test_session):

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
[
{"diff": "M", "id": 1, "name": "John", "city": "London"},
{"diff": "A", "id": 2, "name": "Doe", "city": "New York"},
{"diff": "D", "id": 3, "name": "John", "city": "Berlin"},
{"diff": "S", "id": 4, "name": "Andy", "city": "Tokyo"},
{"diff": CompareStatus.MODIFIED, "id": 1, "name": "John", "city": "London"},
{"diff": CompareStatus.ADDED, "id": 2, "name": "Doe", "city": "New York"},
{"diff": CompareStatus.DELETED, "id": 3, "name": "John", "city": "Berlin"},
{"diff": CompareStatus.SAME, "id": 4, "name": "Andy", "city": "Tokyo"},
],
"id",
)
Expand All @@ -334,10 +335,15 @@ def test_compare_additional_column_on_left(test_session):

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
[
{"diff": "M", "id": 1, "name": "John", "city": "London"},
{"diff": "A", "id": 2, "name": "Doe", "city": "New York"},
{"diff": "D", "id": 3, "name": "Mark", "city": string_default},
{"diff": "M", "id": 4, "name": "Andy", "city": "Tokyo"},
{"diff": CompareStatus.MODIFIED, "id": 1, "name": "John", "city": "London"},
{"diff": CompareStatus.ADDED, "id": 2, "name": "Doe", "city": "New York"},
{
"diff": CompareStatus.DELETED,
"id": 3,
"name": "Mark",
"city": string_default,
},
{"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy", "city": "Tokyo"},
],
"id",
)
Expand All @@ -360,10 +366,10 @@ def test_compare_additional_column_on_right(test_session):

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
[
{"diff": "M", "id": 1, "name": "John"},
{"diff": "A", "id": 2, "name": "Doe"},
{"diff": "D", "id": 3, "name": "Mark"},
{"diff": "M", "id": 4, "name": "Andy"},
{"diff": CompareStatus.MODIFIED, "id": 1, "name": "John"},
{"diff": CompareStatus.ADDED, "id": 2, "name": "Doe"},
{"diff": CompareStatus.DELETED, "id": 3, "name": "Mark"},
{"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy"},
],
"id",
)
Expand Down Expand Up @@ -439,10 +445,10 @@ def test_diff(test_session, status_col):
)

expected = [
("M", fs1_updated, 1),
("A", fs2, 2),
("D", fs3, 3),
("S", fs4, 4),
(CompareStatus.MODIFIED, fs1_updated, 1),
(CompareStatus.ADDED, fs2, 2),
(CompareStatus.DELETED, fs3, 3),
(CompareStatus.SAME, fs4, 4),
]

collect_fields = ["diff", "file", "score"]
Expand Down Expand Up @@ -482,10 +488,10 @@ class Nested(BaseModel):
)

expected = [
("M", fs1_updated, 1),
("A", fs2, 2),
("D", fs3, 3),
("S", fs4, 4),
(CompareStatus.MODIFIED, fs1_updated, 1),
(CompareStatus.ADDED, fs2, 2),
(CompareStatus.DELETED, fs3, 3),
(CompareStatus.SAME, fs4, 4),
]

collect_fields = ["diff", "nested", "score"]
Expand Down
Loading
Loading