diff --git a/src/datachain/diff/__init__.py b/src/datachain/diff/__init__.py index 93451a66d..79390fca1 100644 --- a/src/datachain/diff/__init__.py +++ b/src/datachain/diff/__init__.py @@ -1,5 +1,3 @@ -import random -import string from collections.abc import Sequence from enum import Enum from typing import TYPE_CHECKING, Optional, Union @@ -11,16 +9,12 @@ if TYPE_CHECKING: from datachain.lib.dc import DataChain - C = Column -def get_status_col_name() -> str: - """Returns new unique status col name""" - return "diff_" + "".join( - random.choice(string.ascii_letters) # noqa: S311 - for _ in range(10) - ) +STATUS_COL_NAME = "diff_7aeed3aa17ba4d50b8d1c368c76e16a6" +LEFT_DIFF_COL_NAME = "diff_95f95344064a4b819c8625cd1a5cfc2b" +RIGHT_DIFF_COL_NAME = "diff_5808838a49b54849aa461d7387376d34" class CompareStatus(str, Enum): @@ -101,9 +95,9 @@ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]: compare = right_compare = [c for c in cols if c in right_cols and c not in on] # type: ignore[misc] # get diff column names - diff_col = status_col or get_status_col_name() - ldiff_col = get_status_col_name() - rdiff_col = get_status_col_name() + diff_col = status_col or STATUS_COL_NAME + ldiff_col = LEFT_DIFF_COL_NAME + rdiff_col = RIGHT_DIFF_COL_NAME # adding helper diff columns, which will be removed after left = left.mutate(**{ldiff_col: 1}) @@ -227,7 +221,7 @@ def compare_and_split( ) ``` """ - status_col = get_status_col_name() + status_col = STATUS_COL_NAME res = _compare( left, diff --git a/src/datachain/hash_utils.py b/src/datachain/hash_utils.py new file mode 100644 index 000000000..f3d8efc73 --- /dev/null +++ b/src/datachain/hash_utils.py @@ -0,0 +1,147 @@ +import hashlib +import inspect +import json +import textwrap +from collections.abc import Sequence +from typing import TypeVar, Union + +from sqlalchemy.sql.elements import ( + BinaryExpression, + BindParameter, + ColumnElement, + Label, + Over, + UnaryExpression, +) +from sqlalchemy.sql.functions import Function + +T = TypeVar("T", bound=ColumnElement) +ColumnLike = Union[str, T] + + +def serialize_column_element(expr: Union[str, ColumnElement]) -> dict: # noqa: PLR0911 + """ + Recursively serialize a SQLAlchemy ColumnElement into a deterministic structure. + """ + + # Binary operations: col > 5, col1 + col2, etc. + if isinstance(expr, BinaryExpression): + op = ( + expr.operator.__name__ + if hasattr(expr.operator, "__name__") + else str(expr.operator) + ) + return { + "type": "binary", + "op": op, + "left": serialize_column_element(expr.left), + "right": serialize_column_element(expr.right), + } + + # Unary operations: -col, NOT col, etc. + if isinstance(expr, UnaryExpression): + op = ( + expr.operator.__name__ + if expr.operator is not None and hasattr(expr.operator, "__name__") + else str(expr.operator) + ) + + return { + "type": "unary", + "op": op, + "element": serialize_column_element(expr.element), # type: ignore[arg-type] + } + + # Function calls: func.lower(col), func.count(col), etc. + if isinstance(expr, Function): + return { + "type": "function", + "name": expr.name, + "clauses": [serialize_column_element(c) for c in expr.clauses], + } + + # Window functions: func.row_number().over(partition_by=..., order_by=...) 
+    if isinstance(expr, Over):
+        return {
+            "type": "window",
+            "function": serialize_column_element(expr.element),
+            "partition_by": [
+                serialize_column_element(p) for p in getattr(expr, "partition_by", [])
+            ],
+            "order_by": [
+                serialize_column_element(o) for o in getattr(expr, "order_by", [])
+            ],
+        }
+
+    # Labeled expressions: col.label("alias")
+    if isinstance(expr, Label):
+        return {
+            "type": "label",
+            "name": expr.name,
+            "element": serialize_column_element(expr.element),
+        }
+
+    # Bound values (constants)
+    if isinstance(expr, BindParameter):
+        return {"type": "bind", "value": expr.value}
+
+    # Plain columns
+    if hasattr(expr, "name"):
+        return {"type": "column", "name": expr.name}
+
+    # Fallback: stringify unknown nodes
+    return {"type": "other", "repr": str(expr)}
+
+
+def hash_column_elements(columns: Sequence[ColumnLike]) -> str:
+    """
+    Hash a list of ColumnElements deterministically, dialect-agnostic.
+    Only accepts ordered iterables (like list or tuple).
+    """
+    serialized = [serialize_column_element(c) for c in columns]
+    json_str = json.dumps(serialized, sort_keys=True)  # stable JSON
+    return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
+
+
+def hash_callable(func):
+    """
+    Calculate a hash from a callable.
+    Rules:
+    - Named functions (def) → use source code for stable, cross-version hashing
+    - Lambdas → use bytecode (deterministic in same Python runtime)
+    """
+    if not callable(func):
+        raise TypeError("Expected a callable")
+
+    # Determine if it is a lambda
+    is_lambda = func.__name__ == "<lambda>"
+
+    if not is_lambda:
+        # Try to get exact source of named function
+        try:
+            lines, _ = inspect.getsourcelines(func)
+            payload = textwrap.dedent("".join(lines)).strip()
+        except (OSError, TypeError):
+            # Fallback: bytecode if source not available
+            payload = func.__code__.co_code
+    else:
+        # For lambdas, fall back directly to bytecode
+        payload = func.__code__.co_code
+
+    # Normalize annotations
+    annotations = {
+        k: getattr(v, "__name__", str(v)) for k, v in func.__annotations__.items()
+    }
+
+    # Extras to distinguish functions with same code but different metadata
+    extras = {
+        "name": func.__name__,
+        "defaults": func.__defaults__,
+        "annotations": annotations,
+    }
+
+    # Compute SHA256
+    h = hashlib.sha256()
+    h.update(str(payload).encode() if isinstance(payload, str) else payload)
+    h.update(str(extras).encode())
+    return h.hexdigest()
diff --git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py
index 8f59945dd..940a50e22 100644
--- a/src/datachain/lib/dc/datachain.py
+++ b/src/datachain/lib/dc/datachain.py
@@ -209,6 +209,14 @@ def __repr__(self) -> str:
         self.print_schema(file=file)
         return file.getvalue()
 
+    def hash(self) -> str:
+        """
+        Calculates the SHA hash of this chain. Hash calculation is fast and
+        consistent: it takes into account all the steps added to the chain and
+        their inputs, and the order of the steps is important.
+ """ + return self._query.hash() + def _as_delta( self, on: Optional[Union[str, Sequence[str]]] = None, @@ -682,7 +690,7 @@ def save( # type: ignore[override] if job_id := os.getenv("DATACHAIN_JOB_ID"): catalog.metastore.create_checkpoint( - job_id, # type: ignore[arg-type] + job_id, _hash=hashlib.sha256( # TODO this will be replaced with self.hash() str(uuid4()).encode() ).hexdigest(), diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 87d94ecf4..33a247697 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -1,4 +1,6 @@ import copy +import hashlib +import json import warnings from collections.abc import Iterator, Sequence from dataclasses import dataclass @@ -257,6 +259,11 @@ def serialize(self) -> dict[str, Any]: signals["_custom_types"] = custom_types return signals + def hash(self) -> str: + """Create SHA hash of this schema""" + json_str = json.dumps(self.serialize(), sort_keys=True, separators=(",", ":")) + return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + @staticmethod def _split_subtypes(type_name: str) -> list[str]: """This splits a list of subtypes, including proper square bracket handling.""" diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index 0d096a396..e3fac4254 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -1,3 +1,4 @@ +import hashlib import sys import traceback from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence @@ -12,6 +13,7 @@ from datachain.asyn import AsyncMapper from datachain.cache import temporary_cache from datachain.dataset import RowDict +from datachain.hash_utils import hash_callable from datachain.lib.convert.flatten import flatten from datachain.lib.file import DataModel, File from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError @@ -61,6 +63,9 @@ class UDFAdapter: batch_size: Optional[int] = None batch: int = 1 + def hash(self) -> str: + return self.inner.hash() + def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy: if use_partitioning: return Partition() @@ -151,6 +156,21 @@ def __init__(self): self.output = None self._func = None + def hash(self) -> str: + """ + Creates SHA hash of this UDF function. It takes into account function, + inputs and outputs. 
+        """
+        parts = [
+            hash_callable(self._func),
+            self.params.hash() if self.params else "",
+            self.output.hash(),
+        ]
+
+        return hashlib.sha256(
+            b"".join([bytes.fromhex(part) for part in parts])
+        ).hexdigest()
+
     def process(self, *args, **kwargs):
         """Processing function that needs to be defined by user"""
         if not self._func:
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index b414cfd53..a11c7a900 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -1,4 +1,5 @@
 import contextlib
+import hashlib
 import inspect
 import logging
 import os
@@ -44,6 +45,7 @@
 from datachain.dataset import DatasetDependency, DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
+from datachain.hash_utils import hash_column_elements
 from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import UDFAdapter, _get_cache
@@ -57,6 +59,7 @@
 from datachain.utils import (
     determine_processes,
     determine_workers,
+    ensure_sequence,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
     safe_closing,
@@ -167,6 +170,18 @@ def apply(
     ) -> "StepResult":
         """Apply the processing step."""
 
+    @abstractmethod
+    def hash_inputs(self) -> str:
+        """Calculates hash of step inputs"""
+
+    def hash(self) -> str:
+        """
+        Calculates the step hash from the step name and the hash of its inputs
+        """
+        return hashlib.sha256(
+            f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
+        ).hexdigest()
+
 
 @frozen
 class QueryStep:
@@ -186,6 +201,11 @@ def q(*columns):
             q, dr.columns, dependencies=[(self.dataset, self.dataset_version)]
         )
 
+    def hash(self) -> str:
+        return hashlib.sha256(
+            self.dataset.uri(self.dataset_version).encode()
+        ).hexdigest()
+
 
 def generator_then_call(generator, func: Callable):
     """
@@ -256,6 +276,13 @@ def q(*columns):
 class Subtract(DatasetDiffOperation):
     on: Sequence[tuple[str, str]]
 
+    def hash_inputs(self) -> str:
+        on_bytes = b"".join(
+            f"{a}:{b}".encode() for a, b in sorted(self.on, key=lambda t: (t[0], t[1]))
+        )
+
+        return hashlib.sha256(bytes.fromhex(self.dq.hash()) + on_bytes).hexdigest()
+
     def query(self, source_query: Select, target_query: Select) -> sa.Selectable:
         sq = source_query.alias("source_query")
         tq = target_query.alias("target_query")
@@ -393,6 +420,16 @@ class UDFStep(Step, ABC):
     min_task_size: Optional[int] = None
     batch_size: Optional[int] = None
 
+    def hash_inputs(self) -> str:
+        partition_by = ensure_sequence(self.partition_by or [])
+        parts = [
+            bytes.fromhex(self.udf.hash()),
+            bytes.fromhex(hash_column_elements(partition_by)),
+            str(self.is_generator).encode(),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()
+
     @abstractmethod
     def create_udf_table(self, query: Select) -> "Table":
         """Method that creates a table where temp udf results will be saved"""
@@ -790,6 +827,9 @@ def apply_sql_clause(self, query):
 class SQLSelect(SQLClause):
     args: tuple[Union[Function, ColumnElement], ...]
 
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)
+
     def apply_sql_clause(self, query) -> Select:
         subquery = query.subquery()
         args = [
@@ -806,6 +846,9 @@ def apply_sql_clause(self, query) -> Select:
 class SQLSelectExcept(SQLClause):
     args: tuple[Union[Function, ColumnElement], ...]
+ def hash_inputs(self) -> str: + return hash_column_elements(self.args) + def apply_sql_clause(self, query: Select) -> Select: subquery = query.subquery() args = [c for c in subquery.c if c.name not in set(self.parse_cols(self.args))] @@ -817,6 +860,9 @@ class SQLMutate(SQLClause): args: tuple[Label, ...] new_schema: SignalSchema + def hash_inputs(self) -> str: + return hash_column_elements(self.args) + def apply_sql_clause(self, query: Select) -> Select: original_subquery = query.subquery() to_mutate = {c.name for c in self.args} @@ -846,6 +892,9 @@ def apply_sql_clause(self, query: Select) -> Select: class SQLFilter(SQLClause): expressions: tuple[Union[Function, ColumnElement], ...] + def hash_inputs(self) -> str: + return hash_column_elements(self.expressions) + def __and__(self, other): expressions = self.parse_cols(self.expressions) return self.__class__(expressions + other) @@ -859,6 +908,9 @@ def apply_sql_clause(self, query: Select) -> Select: class SQLOrderBy(SQLClause): args: tuple[Union[Function, ColumnElement], ...] + def hash_inputs(self) -> str: + return hash_column_elements(self.args) + def apply_sql_clause(self, query: Select) -> Select: args = self.parse_cols(self.args) return query.order_by(*args) @@ -868,6 +920,9 @@ def apply_sql_clause(self, query: Select) -> Select: class SQLLimit(SQLClause): n: int + def hash_inputs(self) -> str: + return hashlib.sha256(str(self.n).encode()).hexdigest() + def apply_sql_clause(self, query: Select) -> Select: return query.limit(self.n) @@ -876,12 +931,18 @@ def apply_sql_clause(self, query: Select) -> Select: class SQLOffset(SQLClause): offset: int + def hash_inputs(self) -> str: + return hashlib.sha256(str(self.offset).encode()).hexdigest() + def apply_sql_clause(self, query: "GenerativeSelect"): return query.offset(self.offset) @frozen class SQLCount(SQLClause): + def hash_inputs(self) -> str: + return "" + def apply_sql_clause(self, query): return sqlalchemy.select(f.count(1)).select_from(query.subquery()) @@ -891,6 +952,9 @@ class SQLDistinct(SQLClause): args: tuple[ColumnElement, ...] 
     dialect: str
 
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)
+
     def apply_sql_clause(self, query):
         if self.dialect == "sqlite":
             return query.group_by(*self.args)
@@ -903,6 +967,11 @@ class SQLUnion(Step):
     query1: "DatasetQuery"
     query2: "DatasetQuery"
 
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(self.query1.hash()) + bytes.fromhex(self.query2.hash())
+        ).hexdigest()
+
     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> StepResult:
@@ -939,6 +1008,20 @@ class SQLJoin(Step):
     full: bool
     rname: str
 
+    def hash_inputs(self) -> str:
+        predicates = ensure_sequence(self.predicates or [])
+
+        parts = [
+            bytes.fromhex(self.query1.hash()),
+            bytes.fromhex(self.query2.hash()),
+            bytes.fromhex(hash_column_elements(predicates)),
+            str(self.inner).encode(),
+            str(self.full).encode(),
+            self.rname.encode("utf-8"),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()
+
     def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
         query = dq.apply_steps().select()
         temp_tables.extend(dq.temp_table_names)
@@ -1060,6 +1143,13 @@ class SQLGroupBy(SQLClause):
     cols: Sequence[Union[str, Function, ColumnElement]]
     group_by: Sequence[Union[str, Function, ColumnElement]]
 
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(
+                hash_column_elements(self.cols) + hash_column_elements(self.group_by)
+            )
+        ).hexdigest()
+
     def apply_sql_clause(self, query) -> Select:
         if not self.cols:
             raise ValueError("No columns to select")
@@ -1213,6 +1303,23 @@ def __iter__(self):
     def __or__(self, other):
         return self.union(other)
 
+    def hash(self) -> str:
+        """
+        Calculates the hash of this query, taking into account the hash of the
+        starting step and the hashes of each following step. Ordering is important.
+ """ + hasher = hashlib.sha256() + if self.starting_step: + hasher.update(self.starting_step.hash().encode("utf-8")) + else: + assert self.list_ds_name + hasher.update(self.list_ds_name.encode("utf-8")) + + for step in self.steps: + hasher.update(step.hash().encode("utf-8")) + + return hasher.hexdigest() + @staticmethod def get_table() -> "TableClause": table_name = "".join( diff --git a/src/datachain/utils.py b/src/datachain/utils.py index 8fadd5abf..59fcba4db 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -537,3 +537,9 @@ def getenv_bool(name: str, default: bool = False) -> bool: if val is None: return default return val.lower() in ("1", "true", "yes", "on") + + +def ensure_sequence(x) -> Sequence: + if isinstance(x, Sequence) and not isinstance(x, (str, bytes)): + return x + return [x] diff --git a/tests/conftest.py b/tests/conftest.py index aa5c4d8e3..0496a7d95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ SQLiteMetastore, SQLiteWarehouse, ) -from datachain.dataset import DatasetRecord +from datachain.dataset import DatasetRecord, DatasetVersion from datachain.lib.dc import Sys from datachain.namespace import Namespace from datachain.project import Project @@ -612,7 +612,25 @@ def dataset_record(): name=f"ds_{uuid.uuid4().hex}", description="", attrs=[], - versions=[], + versions=[ + DatasetVersion( + id=1, + uuid=uuid.uuid4().hex, + dataset_id=1, + version="1.0.0", + status=1, + created_at=datetime.now(), + finished_at=datetime.now(), + error_message="", + error_stack="", + num_objects=6, + size=100, + feature_schema=None, + script_output="", + schema=None, + _preview_data=[], + ) + ], status=1, schema={}, feature_schema={}, diff --git a/tests/unit/lib/test_diff.py b/tests/unit/lib/test_diff.py index 8257fb8b9..9d660448b 100644 --- a/tests/unit/lib/test_diff.py +++ b/tests/unit/lib/test_diff.py @@ -529,3 +529,44 @@ class Nested(BaseModel): collect_fields = collect_fields[1:] assert diff.order_by("nested.file.source").to_list(*collect_fields) == expected + + +def test_multiple_diffs(test_session): + ds1 = dc.read_values( + id=[1, 2, 3, 4, 5], + name=["John", "Doe", "Andy", "Matt", "Rick"], + session=test_session, + ) + ds2 = dc.read_values( + id=[1, 2, 3], + name=["John", "Doe", "Andy"], + session=test_session, + ) + ds3 = dc.read_values( + id=[1, 2], + name=["John", "Doe"], + session=test_session, + ) + + diff = ds1.diff( + ds2, + added=True, + deleted=False, + modified=False, + same=False, + on=["id"], + status_col="diff", + ).diff( + ds3, + added=True, + deleted=False, + modified=False, + same=False, + on=["id"], + status_col="diff", + ) + + assert diff.order_by("id").to_list(*["diff", "id", "name"]) == [ + ("A", 4, "Matt"), + ("A", 5, "Rick"), + ] diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index a7d8e94ab..69eeaecc5 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -1563,3 +1563,27 @@ def test_to_partial_complex_signal_error_invalid_field(): SignalSchemaError, match="Field nonexistent not found in custom type" ): schema.to_partial("file.nonexistent") + + +@pytest.mark.parametrize( + "schema,_hash", + [ + ( + { + "name": Optional[str], + "feature": Optional[MyType1], + }, + "73aa5b0c9e511027dc3aca0baea50b43a5451aad33f3261cc04c600649ff44ed", + ), + ( + {"file": File}, + "26a08b3793e738814f199c89c4582f9bde052ff3dcba84c2020535063df4c36c", + ), + ( + {}, + "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + ), + ], +) +def 
test_hash(schema, _hash): + assert SignalSchema(schema).hash() == _hash diff --git a/tests/unit/test_datachain_hash.py b/tests/unit/test_datachain_hash.py new file mode 100644 index 000000000..2446c559d --- /dev/null +++ b/tests/unit/test_datachain_hash.py @@ -0,0 +1,173 @@ +from unittest.mock import patch + +import pytest +from pydantic import BaseModel + +import datachain as dc +from datachain import func +from datachain.lib.dc import C + + +class Person(BaseModel): + name: str + age: float + + +class PersonAgg(BaseModel): + name: str + ages: float + + +class Player(BaseModel): + name: str + sport: str + + +class Worker(BaseModel): + name: str + age: float + title: str + + +persons = [ + Person(name="p1", age=10), + Person(name="p2", age=20), + Person(name="p3", age=30), + Person(name="p4", age=40), + Person(name="p5", age=40), + Person(name="p6", age=60), +] + + +players = [ + Player(name="p1", sport="baksetball"), + Player(name="p2", sport="soccer"), + Player(name="p3", sport="baseball"), + Player(name="p4", sport="tennis"), +] + + +@pytest.fixture +def mock_get_listing(): + with patch("datachain.lib.dc.storage.get_listing") as mock: + mock.return_value = ("lst__s3://my-bucket", "", "", True) + yield mock + + +def test_read_values(): + pytest.skip( + "Hash of the chain started with read_values is currently inconsistent," + " meaning it produces different hash every time. This happens because we" + " create random name dataset in the process. Correct solution would be" + " to calculate hash of all those input values." + ) + assert dc.read_values(num=[1, 2, 3]).hash() == "" + + +def test_read_storage(mock_get_listing): + assert dc.read_storage("s3://bucket").hash() == ( + "c38b6f4ebd7f0160d9f900016aad1e6781acd463f042588cfe793e9d189a8a0e" + ) + + +def test_read_dataset(test_session): + dc.read_values(num=[1, 2, 3], session=test_session).save("dev.animals.cats") + assert dc.read_dataset( + name="dev.animals.cats", version="1.0.0", session=test_session + ).hash() == ("51f2e5b81e40a22062a75c1590d0ccab880d182df9b39f610c6ccc503a5eb33c") + + +def test_order_of_steps(mock_get_listing): + assert ( + dc.read_storage("s3://bucket").mutate(new=10).filter(C("age") > 20).hash() + ) == "08a6c5657feaea55c734bc8e2b3eb0733ea692d4eab5fa78fa26409e6c2af098" + + assert ( + dc.read_storage("s3://bucket").filter(C("age") > 20).mutate(new=10).hash() + ) == "e91b84094233a2bf4d08d6a95e55529a65d900399be3a05dc3e2ca0401f8f25b" + + +def test_all_possible_steps(test_session): + persons_ds_name = "dev.my_pr.persons" + players_ds_name = "dev.my_pr.players" + + def map_worker(person: Person) -> Worker: + return Worker( + name=person.name, + age=person.age, + title="worker", + ) + + def gen_persons(person): + yield Person( + age=person.age * 2, + name=person.name + "_suf", + ) + + def agg_persons(persons): + return PersonAgg(ages=sum(p.age for p in persons), name=persons[0].age) + + dc.read_values(person=persons, session=test_session).save(persons_ds_name) + dc.read_values(player=players, session=test_session).save(players_ds_name) + + players_chain = dc.read_dataset( + players_ds_name, version="1.0.0", session=test_session + ) + + assert ( + dc.read_dataset(persons_ds_name, version="1.0.0", session=test_session) + .mutate(age_double=C("person.age") * 2) + .filter(C("person.age") > 20) + .order_by("person.name", "person.age") + .gen( + person=gen_persons, + output=Person, + ) + .map( + worker=map_worker, + params="person", + output={"worker": Worker}, + ) + .agg( + persons=agg_persons, + partition_by=C.person.name, + 
params="person", + output={"persons": PersonAgg}, + ) + .merge(players_chain, "persons.name", "player.name") + .distinct("persons.name") + .sample(10) + .offset(2) + .limit(5) + .group_by(age_avg=func.avg("persons.age"), partition_by="persons.name") + .select("persons.name", "age_avg") + .subtract( + players_chain, + on=["persons.name"], + right_on=["player.name"], + ) + .hash() + ) == "44b231652aee9712444ee26d5ecc77e6b87f768d17e6b8333303764d3706413b" + + +def test_diff(test_session): + persons_ds_name = "dev.my_pr.persons" + players_ds_name = "dev.my_pr.players" + + dc.read_values(person=persons, session=test_session).save(persons_ds_name) + dc.read_values(player=players, session=test_session).save(players_ds_name) + + players_chain = dc.read_dataset( + players_ds_name, version="1.0.0", session=test_session + ) + + assert ( + dc.read_dataset(persons_ds_name, version="1.0.0", session=test_session) + .diff( + players_chain, + on=["person.name"], + right_on=["player.name"], + status_col="diff", + ) + .hash() + ) == "aef929f3bf247966703534aa3daffb76fa8802d64660293deb95155ffacd8b77" diff --git a/tests/unit/test_hash_utils.py b/tests/unit/test_hash_utils.py new file mode 100644 index 000000000..d8b8d0bbd --- /dev/null +++ b/tests/unit/test_hash_utils.py @@ -0,0 +1,109 @@ +import pytest +import sqlalchemy as sa + +from datachain import C, func +from datachain.hash_utils import hash_callable, hash_column_elements + + +def double(x): + return x * 2 + + +def double_arg_annot(x: int): + return x * 2 + + +def double_arg_and_return_annot(x: int) -> int: + return x * 2 + + +lambda1 = lambda x: x * 2 # noqa: E731 +lambda2 = lambda y: y + 1 # noqa: E731 +lambda3 = lambda z: z - 1 # noqa: E731 + + +@pytest.mark.parametrize( + "expr,result", + [ + ( + [C("name")], + "8e95b415d0950727bb698f1a9fcaf28a4e088afe19d8256ecaa581022cde9365", + ), + ( + [C("name"), C("age")], + "c4f98a6350d621d16255490fe0d522b61749720a27f6a42b262090bad4100092", + ), + ( + [func.avg("age")], + "ddc23abe88c722954e568f7db548ddcbd060eed1a1a815bfcaabd1dce8add3aa", + ), + ( + [ + func.row_number().over( + func.window(partition_by="file.name", order_by="file.name") + ) + ], + "9da0e1581399e92f628c00879422835fc05ada2584e9962c0edb20f87637e8bf", + ), + ( + [C("age").label("user_age")], + "8a0a3d4e99972dc5fdc462b9981b309bbc6b0cc86d73880d56108ef0553bd426", + ), + ( + [C("age") > 20], + "6ba1c4384c710fe439e84749d7d08d675cb03d4c3683eb55bce11efd42372b67", + ), + ( + [sa.and_(C("age") > 20, C("name") != "")], + "a27c392ad1c294783ab70175478bf7cf2110fe559bf68504026f773e5aa361ab", + ), + ( + [], + "4f53cda18c2baa0c0354bb5f9a3ecbe5ed12ab4d8e11ba873c2f11161202b945", + ), + ], +) +def test_hash_column_elements(expr, result): + assert hash_column_elements(expr) == result + + +@pytest.mark.parametrize( + "func,expected_hash", + [ + (double, "aba077bec793c25e277923cde6905636a80595d1cb9a92a2c53432fc620d2f44"), + ( + double_arg_annot, + "391b2bfe41cfb76a9bb7e72c5ab4333f89124cd256d87cee93378739d078400f", + ), + ( + double_arg_and_return_annot, + "5f6c61c05d2c01a1b3745a69580cbf573ecdce2e09cce332cb83db0b270ff870", + ), + ], +) +def test_hash_named_functions(func, expected_hash): + h = hash_callable(func) + assert h == expected_hash + + +@pytest.mark.parametrize( + "func", + [ + lambda1, + lambda2, + lambda3, + ], +) +def test_lambda_same_hash(func): + h1 = hash_callable(func) + h2 = hash_callable(func) + assert h1 == h2 # same object produces same hash + + +def test_lambda_different_hashes(): + h1 = hash_callable(lambda1) + h2 = hash_callable(lambda2) + h3 = 
hash_callable(lambda3) + + # Ensure hashes are all different + assert len({h1, h2, h3}) == 3 diff --git a/tests/unit/test_query_steps_hash.py b/tests/unit/test_query_steps_hash.py new file mode 100644 index 000000000..e0d005767 --- /dev/null +++ b/tests/unit/test_query_steps_hash.py @@ -0,0 +1,505 @@ +import math +from dataclasses import replace + +import pytest +import sqlalchemy as sa +from pydantic import BaseModel + +import datachain as dc +from datachain import C, func +from datachain.func.func import Func +from datachain.lib.signal_schema import SignalSchema +from datachain.lib.udf import Aggregator, Generator, Mapper +from datachain.lib.udf_signature import UdfSignature +from datachain.query.dataset import ( + QueryStep, + RowGenerator, + SQLCount, + SQLDistinct, + SQLFilter, + SQLGroupBy, + SQLJoin, + SQLLimit, + SQLMutate, + SQLOffset, + SQLOrderBy, + SQLSelect, + SQLSelectExcept, + SQLUnion, + Subtract, + UDFSignal, +) + + +class CustomFeature(BaseModel): + sqrt: float + my_name: str + + +def double(x): + return x * 2 + + +def double2(y): + return 7 * 2 + + +def double_gen(x): + yield x * 2 + + +def double_gen_multi_arg(x, y): + yield x * 2 + yield y * 2 + + +def double_default(x, y=2): + return x * y + + +def double_kwonly(x, *, factor=3): + return x * factor + + +def map_custom_feature(m_fr): + return CustomFeature( + sqrt=math.sqrt(m_fr.count), + my_name=m_fr.nnn + "_suf", + ) + + +def custom_feature_gen(m_fr): + yield CustomFeature( + sqrt=math.sqrt(m_fr.count), + my_name=m_fr.nnn + "_suf", + ) + + +@pytest.fixture +def numbers_dataset(test_session): + """ + Fixture to create dataset with stable / constant UUID to have consistent + hash values in tests as it goes into chain hash calculation + """ + dc.read_values(num=list(range(100)), session=test_session).save("dev.num.numbers") + test_session.catalog.metastore.update_dataset_version( + test_session.catalog.get_dataset( + "numbers", namespace_name="dev", project_name="num" + ), + "1.0.0", + uuid="9045d46d-7c57-4442-aae3-3ca9e9f286c4", + ) + + return test_session.catalog.get_dataset( + "numbers", namespace_name="dev", project_name="num" + ) + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + ( + (C("name"), C("age") * 10, func.avg("id"), C("country").label("country")), + "d03395827dcdddc2b2c3f0a3dafb71affa89c7f3b03b89e42734af2aea0e05ba", + ), + ((), "3245ba76bc1e4b1b1d4d775b88448ff02df9473bd919929166c70e9e2b245345"), + ( + (C("name"),), + "fe30656afd177ef32da191cc5ab3c68268282c382ef405d753e128b69767602f", + ), + ( + (func.rand().label("random"),), + "f99e28cd2023ae5a7855c72ffd44fc99e36442818d3855f46b3aed576ffc1d30", + ), + (("name",), "46eeec88c5f842bd478d3ec87032c49b22adcdd46572463b0acde4b2bac0900a"), + ], +) +def test_select_hash(inputs, _hash): + assert SQLSelect(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + ( + (C("name"), C("age") * 10, func.avg("id"), C("country").label("country")), + "19894de08d545f3db85242be292dea0bb1ef47b0feaaf2c9359b159c7aa588c6", + ), + ((), "0d27e4cfa3801628afc535190c64a426d9db66e5145c57129b9f5ca0935ef29e"), + ( + (C("name"),), + "9515589e525bfa21cec0b68edf41c09e8df26e5c3023fd0775ba0ea02c9f6c8f", + ), + (("name",), "e26923a0433e549e680a4bcbc5cb95bb9a523c4b47ae23b07b2a928a609fc498"), + ], +) +def test_select_except_hash(inputs, _hash): + assert SQLSelectExcept(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + ( + (sa.and_(C("name") != "John", C("age") * 10 > 100)), + "ba98f1a292cc7e95402899a43e5392708bcf448332e060becb24956fb531bfd0", 
+ ), + ((), "19e718af35ddc311aa892756fa4f95413ce17db7c8b27f68200d9c3ce0fc8dbf"), + ( + (C("files.path").glob("*.jpg"),), + "c77898b24747f5106fd3793862d6c227e0423e096c6859ac95c27a9f7f7a824b", + ), + ( + sa.or_(C("age") > 50, C("country") == "US"), + "025880292c522fe7d3cf1163a11dc33b12c333e53d09efb12e40be08f31f95a2", + ), + ], +) +def test_filter_hash(inputs, _hash): + assert SQLFilter(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,schema,_hash", + [ + ( + {"new_id": func.sum("id")}, + SignalSchema({"id": int}), + "d8e3af2fa2b5357643f80702455f0bbecb795b38bbb37eef24c644315e28617c", + ), + ( + {"new_id": C("id") * 10, "old_id": C("id")}, + SignalSchema({"id": int}), + "beea21224d3e2fae077a6a38d663fbaea0549fd38508b48fac3454cd76eca0df", + ), + ( + {}, + SignalSchema({"id": int}), + "b9717325e70a10ccd55c7faa22d5099ac8d5726d1a3c0eb3cfb001c7f628ce7f", + ), + ], +) +def test_mutate_hash(inputs, schema, _hash): + # transforming input into format SQLMutate expects + inputs = ( + v.label(k).get_column(schema) if isinstance(v, Func) else v.label(k) + for k, v in inputs.items() + ) + assert SQLMutate(inputs, new_schema=None).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + ( + (C("name"), C("age")), + "8368b3239fd66422c18d561d2b61dbbae9fd88f9c935f67719b0d12ada50ffb6", + ), + (("name",), "b3562b4508052e5a57bc84ae862255939df294eb079e124c5af61fc21044343e"), + ( + (sa.desc(C("name")),), + "fd91c8cfe480debf1cdcf2b3f91462393a75042d0752a813ecc65dfed1ac7a6c", + ), + ((), "c525013178ef24a807af6d4dd44d108c20a5224eb3ab88b84c55c635ec32ba04"), + ], +) +def test_order_by_hash(inputs, _hash): + assert SQLOrderBy(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + (5, "9fc462c7b5fe66106c8056b9f361817523de5c9f8d4e4b847e79cb02feba1351"), + (0, "1da7ad424bfdb853e852352fbb853722eb5fdc119592a778679aa00ba29f971a"), + ], +) +def test_limit_hash(inputs, _hash): + assert SQLLimit(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + (5, "ff65be6bef149f6f2568f33c2bd0ac3362018a504caadf52c221a2e64acc5bb3"), + (0, "e88121711a1fa5da46ea2305e0d6fbeebe63f5b575450c628e7bf6f81e73aa46"), + ], +) +def test_offset_hash(inputs, _hash): + assert SQLOffset(inputs).hash() == _hash + + +@pytest.mark.parametrize( + "_hash", + [ + "8867973da58bd4d14c023fa9bad98dc50c18ba69240347216f7a8a1c7e70d377", + "8867973da58bd4d14c023fa9bad98dc50c18ba69240347216f7a8a1c7e70d377", + ], +) +def test_count_hash(_hash): + assert SQLCount().hash() == _hash + + +@pytest.mark.parametrize( + "inputs,_hash", + [ + (("name",), "bb0a1acba3bce39d31cc05dc01e57fc7265e451154187a6f93fbcf2001525c51"), + ( + ("name", "age"), + "29203756f44599f2728c70d75d92ff7af6110c8602e25839127c736d25a30c4b", + ), + ((), "7d4efeefbe9d1694bb89e7bf8b2d3f1d96ed0603e312b48d247d0ed3c881bf48"), + ], +) +def test_distinct_hash(inputs, _hash): + assert SQLDistinct(inputs, dialect=None).hash() == _hash + + +def test_union_hash(test_session, numbers_dataset): + chain1 = dc.read_dataset("dev.num.numbers").filter(C("num") > 50).limit(10) + chain2 = dc.read_dataset("dev.num.numbers").filter(C("num") < 50).limit(20) + + assert SQLUnion(chain1._query, chain2._query).hash() == ( + "c13c83192846342814d693740085494d509247bb3512af5966e66e2ed10bc8ad" + ) + + +@pytest.mark.parametrize( + "predicates,inner,full,rname,_hash", + [ + ( + "id", + True, + False, + "{name}_right", + "cd3504449c68fce0e6a687a7494b8a3ddb8e1b9b3452147c234c384fbbc201b2", + ), + ( + ("id", "name"), + False, + True, + "{name}_r", + 
"f637c82a2a197823ec5dc6614623c860d682110ceec60821759534a9e24ec6cf", + ), + ], +) +def test_join_hash( + test_session, numbers_dataset, predicates, inner, full, rname, _hash +): + chain1 = dc.read_dataset("dev.num.numbers").filter(C("num") > 50).limit(10) + chain2 = dc.read_dataset("dev.num.numbers").filter(C("num") < 50).limit(20) + + assert ( + SQLJoin( + test_session.catalog, + chain1._query, + chain2._query, + predicates, + inner, + full, + rname, + ).hash() + == _hash + ) + + +@pytest.mark.parametrize( + "columns,partition_by,_hash", + [ + ( + {"cnt": func.count(), "sum": func.sum("id")}, + [ + C("id"), + ], + "0f28ac6aa6daee1892d5e79b559c9c1c2072cec2d53d4e0f12c3ae42db1a869f", + ), + ( + {"cnt": func.count(), "sum": func.sum("id")}, + [C("id"), C("name")], + "f8ef71fc6d3438cd6905e0a4d96f9b13a465c4a955127d929837e3f0ac3d31d6", + ), + ( + {"cnt": func.count()}, + [], + "fe833a3ce997c919bcf3a2c5de1e76f2481a0937320f9fa0c2a8b3c191cea480", + ), + ], +) +def test_group_by_hash(columns, partition_by, _hash): + schema = SignalSchema({"id": int}) + # transforming inputs into format SQLGroupBy expects + columns = [v.get_column(schema, label=k) for k, v in columns.items()] + assert SQLGroupBy(columns, partition_by).hash() == _hash + + +@pytest.mark.parametrize( + "on,_hash", + [ + ( + [("id", "id")], + "4efcdbe669ea1c073bb12339f7bba79a78d61959988b12be975bffbf5dab0efd", + ), + ( + [("id", "id"), ("name", "name")], + "35553413a5a988fc8d3b73694881603f50143b1e1846a6d8748a6274519c64db", + ), + ( + [], + "9e9089070d5cfa3895ac03a53fd586149b84df49d0b2adbbe970fb6066e4b663", + ), + ], +) +def test_subtract_hash(test_session, numbers_dataset, on, _hash): + chain = dc.read_dataset("dev.num.numbers").filter(C("num") > 50).limit(20) + assert Subtract(chain._query, test_session.catalog, on).hash() == _hash + + +@pytest.mark.parametrize( + "func,params,output,_hash", + [ + ( + double, + ["x"], + {"double": int}, + "c62dcb3c110b1cadb47dd3b6499d7f4da351417fbe806a3e835237928a468708", + ), + ( + double2, + ["y"], + {"double": int}, + "674838e9557ad24b9fc68c6146b781e02fd7e0ad64361cc20c055f47404f0a95", + ), + ( + double_default, + ["x"], + {"double": int}, + "f25afd25ebb5f054bab721bea9126c5173c299abb0cbb3fd37d5687a7693a655", + ), + ( + double_kwonly, + ["x"], + {"double": int}, + "12f3620f703c541e0913c27cd828a8fe6e446f62f3d0b2a4ccfa5a1d9e2472e7", + ), + ( + map_custom_feature, + ["t1"], + {"x": CustomFeature}, + "b4edceaa18ed731085e1c433a6d21deabec8d92dfc338fb1d709ed7951977fc5", + ), + ], +) +def test_udf_mapper_hash( + func, + params, + output, + _hash, +): + sign = UdfSignature.parse("", {}, func, params, output, False) + udf_adapter = Mapper._create(sign, SignalSchema(sign.params)).to_udf_wrapper() + assert UDFSignal(udf_adapter, None).hash() == _hash + + +@pytest.mark.parametrize( + "func,params,output,_hash", + [ + ( + double_gen, + ["x"], + {"double": int}, + "c7ae1a50df841da2012c8422be87bfb29b101113030c43309ab6619011cdcc1c", + ), + ( + double_gen_multi_arg, + ["x", "y"], + {"double": int}, + "850352183532e057ec9c914bda906f15eb2223298e2cbd0c3585bf95a54e15e9", + ), + ( + custom_feature_gen, + ["t1"], + {"x": CustomFeature}, + "7ff702d242612cbb83cbd1777aa79d2792fb2a341db5ea406cd9fd3f42543b9c", + ), + ], +) +def test_udf_generator_hash( + func, + params, + output, + _hash, +): + sign = UdfSignature.parse("", {}, func, params, output, False) + udf_adapter = Generator._create(sign, SignalSchema(sign.params)).to_udf_wrapper() + assert RowGenerator(udf_adapter, None).hash() == _hash + + +@pytest.mark.parametrize( + 
"func,params,output,partition_by,_hash", + [ + ( + double_gen, + ["x"], + {"double": int}, + [C("x")], + "27f07777802865d1f78bba78edce4233cc1b155dbce1b0af3d1e93b290fba04e", + ), + ( + custom_feature_gen, + ["t1"], + {"x": CustomFeature}, + [C.t1.my_name], + "f3d2861f9c080529fe1ab33106c59f157e48ed6422dfb84c3e62e12b62db7fa7", + ), + ], +) +def test_udf_aggregator_hash( + func, + params, + output, + partition_by, + _hash, +): + sign = UdfSignature.parse("", {}, func, params, output, False) + udf_adapter = Aggregator._create(sign, SignalSchema(sign.params)).to_udf_wrapper() + assert RowGenerator(udf_adapter, None, partition_by=partition_by).hash() == _hash + + +@pytest.mark.parametrize( + "namespace_name,project_name,name,version,_hash", + [ + ( + "default", + "default", + "numbers", + "1.0.4", + "8173fb1d88df5cca3e904cbd17a9b80a0c8a682425c32cd95e32e1e196b7eff8", + ), + ( + "dev", + "animals", + "cats", + "1.0.1", + "e0aec7fe323ae3482ee2e74030a87ebb73dbb823ce970e15fdfcbd43e7abe2da", + ), + ( + "system", + "listing", + "lst__s3://bucket", + "1.0.1", + "19dff9f21030312c7469de7284cac2841063c22c62a7948a68f25ca018777c6d", + ), + ], +) +def test_query_step_hash( + dataset_record, namespace_name, project_name, name, version, _hash +): + namespace = replace(dataset_record.project.namespace, name=namespace_name) + project = dataset_record.project + project = replace(project, namespace=namespace) + project = replace(project, name=project_name) + dataset_record.project = project + dataset_record.name = name + dataset_record.versions[0].version = version + + assert QueryStep(None, dataset_record, version).hash() == _hash