diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index e3fac4254..4f429316d 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -160,9 +160,15 @@ def hash(self) -> str:
         """
         Creates SHA hash of this UDF function.
         It takes into account function, inputs and outputs.
+
+        For function-based UDFs, hashes self._func.
+        For class-based UDFs, hashes the process method.
         """
+        # Hash user code: either _func (function-based) or process method (class-based)
+        func_to_hash = self._func if self._func else self.process
+
         parts = [
-            hash_callable(self._func),
+            hash_callable(func_to_hash),
             self.params.hash() if self.params else "",
             self.output.hash(),
         ]
diff --git a/tests/unit/test_datachain_hash.py b/tests/unit/test_datachain_hash.py
index 8950b77b2..076266847 100644
--- a/tests/unit/test_datachain_hash.py
+++ b/tests/unit/test_datachain_hash.py
@@ -1,5 +1,6 @@
 from unittest.mock import patch
 
+import pandas as pd
 import pytest
 from pydantic import BaseModel
 
@@ -7,6 +8,11 @@
 from datachain import func
 from datachain.lib.dc import C
 
+DF_DATA = {
+    "first_name": ["Alice", "Bob", "Charlie", "David", "Eva"],
+    "age": [25, 30, 35, 40, 45],
+}
+
 
 class Person(BaseModel):
     name: str
@@ -55,13 +61,55 @@ def mock_get_listing():
 
 
 def test_read_values():
-    pytest.skip(
-        "Hash of the chain started with read_values is currently inconsistent,"
-        " meaning it produces different hash every time. This happens because we"
-        " create random name dataset in the process. Correct solution would be"
-        " to calculate hash of all those input values."
+    """
+    Hash of the chain started with read_values is currently inconsistent.
+    Goal of this test is just to check it doesn't break.
+    """
+    assert dc.read_values(num=[1, 2, 3]).hash() is not None
+
+
+def test_read_csv(test_session, tmp_dir):
+    """
+    Hash of the chain started with read_csv is currently inconsistent.
+    Goal of this test is just to check it doesn't break.
+    """
+    path = tmp_dir / "test.csv"
+    pd.DataFrame(DF_DATA).to_csv(path, index=False)
+    assert dc.read_csv(path.as_uri(), session=test_session).hash() is not None
+
+
+@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
+def test_read_json(test_session, tmp_dir):
+    """
+    Hash of the chain started with read_json is currently inconsistent.
+    Goal of this test is just to check it doesn't break.
+    """
+    path = tmp_dir / "test.jsonl"
+    dc.read_pandas(pd.DataFrame(DF_DATA), session=test_session).to_jsonl(path)
+    assert (
+        dc.read_json(path.as_uri(), format="jsonl", session=test_session).hash()
+        is not None
     )
-    assert dc.read_values(num=[1, 2, 3]).hash() == ""
+
+
+def test_read_pandas(test_session, tmp_dir):
+    """
+    Hash of the chain started with read_pandas is currently inconsistent.
+    Goal of this test is just to check it doesn't break.
+    """
+    df = pd.DataFrame(DF_DATA)
+    assert dc.read_pandas(df, session=test_session).hash() is not None
+
+
+def test_read_parquet(test_session, tmp_dir):
+    """
+    Hash of the chain started with read_parquet is currently inconsistent.
+    Goal of this test is just to check it doesn't break.
+    """
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.parquet"
+    dc.read_pandas(df, session=test_session).to_parquet(path)
+    assert dc.read_parquet(path.as_uri(), session=test_session).hash() is not None
 
 
 def test_read_storage(mock_get_listing):
diff --git a/tests/unit/test_query_steps_hash.py b/tests/unit/test_query_steps_hash.py
index 4fc1e56dd..260336c78 100644
--- a/tests/unit/test_query_steps_hash.py
+++ b/tests/unit/test_query_steps_hash.py
@@ -75,6 +75,22 @@ def custom_feature_gen(m_fr):
     )
 
 
+# Class-based UDFs for testing hash calculation
+class DoubleMapper(Mapper):
+    """Class-based Mapper that overrides process()."""
+
+    def process(self, x):
+        return x * 2
+
+
+class TripleGenerator(Generator):
+    """Class-based Generator that overrides process()."""
+
+    def process(self, x):
+        yield x * 3
+        yield x * 3 + 1
+
+
 @pytest.fixture
 def numbers_dataset(test_session):
     """
@@ -394,6 +410,12 @@ def test_subtract_hash(test_session, numbers_dataset, on, _hash):
             {"x": CustomFeature},
             "b4edceaa18ed731085e1c433a6d21deabec8d92dfc338fb1d709ed7951977fc5",
         ),
+        (
+            DoubleMapper(),
+            ["x"],
+            {"double": int},
+            "7994436106fef0486b04078b02ee437be3aa73ade2d139fb8c020e2199515e26",
+        ),
     ],
 )
 def test_udf_mapper_hash(
@@ -428,6 +450,12 @@ def test_udf_mapper_hash(
             {"x": CustomFeature},
             "7ff702d242612cbb83cbd1777aa79d2792fb2a341db5ea406cd9fd3f42543b9c",
         ),
+        (
+            TripleGenerator(),
+            ["x"],
+            {"triple": int},
+            "02b4c6bf98ffa011b7c62f3374f219f21796ece5b001d99e4c2f69edf0a94f4a",
+        ),
     ],
 )
 def test_udf_generator_hash(