From 62475ab48cb0f9c93fc12e2ba4350ab343e63150 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Sat, 6 Dec 2025 18:25:27 -0800 Subject: [PATCH 1/2] revise and simplify to_partial --- src/datachain/lib/signal_schema.py | 68 +++--- tests/unit/lib/test_signal_schema.py | 348 ++++++++++++--------------- 2 files changed, 191 insertions(+), 225 deletions(-) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index c77a4ee61..c83a4ecf7 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -1218,6 +1218,28 @@ def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901 signal_partials: dict[str, str] = {} partial_versions: dict[str, int] = {} + def _ensure_partial_custom_type(type_name: str) -> None: + nonlocal data_model_bases + + if ( + type_name in custom_types + or type_name in schema_custom_types + or type_name in NAMES_TO_TYPES + or "Partial" not in type_name + ): + return + + if data_model_bases is None: + data_model_bases = SignalSchema._get_bases(DataModel) + + parsed_name, _ = ModelStore.parse_name_version(type_name) + schema_custom_types[type_name] = CustomType( + schema_version=2, + name=parsed_name, + fields={}, + bases=[(parsed_name, "__main__", type_name), *data_model_bases], + ) + def _type_name_to_partial(signal_name: str, type_name: str) -> str: # Check if we need to create a partial for this type # Only create partials for custom types that are in the custom_types dict @@ -1273,44 +1295,22 @@ def _type_name_to_partial(signal_name: str, type_name: str) -> str: f"Field {signal} not found in custom type {parent_type}" ) - # Check if this is the last part and if the column type is a complex - is_last_part = i == len(column_parts) - 1 - is_complex_signal = signal_type in custom_types - - if is_last_part and is_complex_signal: - schema[column] = signal_type - # Also need to remove the partial schema entry we created for the - # parent since we're promoting the nested complex column to root - parent_signal = column_parts[0] - schema.pop(parent_signal, None) - # Don't create partial types for this case - break - - # Create partial type for this field - partial_type = _type_name_to_partial( - ".".join(column_parts[: i + 1]), - signal_type, - ) + is_leaf = i == len(column_parts) - 1 - if parent_type_partial in schema_custom_types: - schema_custom_types[parent_type_partial].fields[signal] = ( - partial_type - ) + if is_leaf and signal_type in custom_types: + # Selecting an entire nested complex field: keep the original type + partial_type = signal_type else: - if data_model_bases is None: - data_model_bases = SignalSchema._get_bases(DataModel) - - partial_type_name, _ = ModelStore.parse_name_version(partial_type) - schema_custom_types[parent_type_partial] = CustomType( - schema_version=2, - name=partial_type_name, - fields={signal: partial_type}, - bases=[ - (partial_type_name, "__main__", partial_type), - *data_model_bases, - ], + partial_type = _type_name_to_partial( + ".".join(column_parts[: i + 1]), + signal_type, ) + _ensure_partial_custom_type(partial_type) + _ensure_partial_custom_type(parent_type_partial) + + schema_custom_types[parent_type_partial].fields[signal] = partial_type + parent_type, parent_type_partial = signal_type, partial_type if schema_custom_types: diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 2327dcd09..40eb022ec 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -1502,66 +1502,49 @@ def 
test_column_types(column_type, signal_type): def test_to_partial(): schema = SignalSchema({"name": str, "age": float, "f": File}) partial = schema.to_partial("name", "f.path") - assert partial.serialize() == { - "_custom_types": { - "FilePartial1@v1": { - "bases": [ - ("FilePartial1", "datachain.lib.signal_schema", "FilePartial1@v1"), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "path": "str", - }, - "hidden_fields": [], - "name": "FilePartial1@v1", - "schema_version": 2, - }, - }, - "name": "str", - "f": "FilePartial1@v1", - } + assert set(partial.values) == {"name", "f"} + assert partial.values["name"] is str + + file_partial = partial.values["f"] + assert issubclass(file_partial, DataModel) + assert file_partial.__name__.startswith("FilePartial") + assert set(file_partial.model_fields) == {"path"} + assert file_partial.model_fields["path"].annotation is str + + serialized = partial.serialize() + assert serialized["name"] == "str" + assert serialized["f"] == ModelStore.get_name(file_partial) + assert ModelStore.get_name(file_partial) in serialized["_custom_types"] def test_to_partial_duplicate(): schema = SignalSchema({"name": str, "age": float, "f1": File, "f2": File}) partial = schema.to_partial("age", "f1.path", "f2.source") - assert partial.serialize() == { - "_custom_types": { - "FilePartial1@v1": { - "bases": [ - ("FilePartial1", "datachain.lib.signal_schema", "FilePartial1@v1"), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "path": "str", - }, - "hidden_fields": [], - "name": "FilePartial1@v1", - "schema_version": 2, - }, - "FilePartial2@v1": { - "bases": [ - ("FilePartial2", "datachain.lib.signal_schema", "FilePartial2@v1"), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "source": "str", - }, - "hidden_fields": [], - "name": "FilePartial2@v1", - "schema_version": 2, - }, - }, - "age": "float", - "f1": "FilePartial1@v1", - "f2": "FilePartial2@v1", - } + assert set(partial.values) == {"age", "f1", "f2"} + assert partial.values["age"] is float + + f1_partial = partial.values["f1"] + f2_partial = partial.values["f2"] + + assert issubclass(f1_partial, DataModel) + assert issubclass(f2_partial, DataModel) + assert f1_partial is not f2_partial + + assert f1_partial.__name__.startswith("FilePartial") + assert f2_partial.__name__.startswith("FilePartial") + + assert set(f1_partial.model_fields) == {"path"} + assert f1_partial.model_fields["path"].annotation is str + + assert set(f2_partial.model_fields) == {"source"} + assert f2_partial.model_fields["source"].annotation is str + + serialized = partial.serialize() + assert serialized["age"] == "float" + assert serialized["f1"] == ModelStore.get_name(f1_partial) + assert serialized["f2"] == ModelStore.get_name(f2_partial) + assert ModelStore.get_name(f1_partial) in serialized["_custom_types"] + assert ModelStore.get_name(f2_partial) in serialized["_custom_types"] def test_to_partial_nested(): @@ -1571,59 +1554,32 @@ class Custom(DataModel): schema = SignalSchema({"name": str, "age": float, "f": File, "custom": Custom}) partial = schema.to_partial("name", "f.path", "custom.file.source") - assert partial.serialize() == { - "_custom_types": { - "FilePartial1@v1": { - "bases": [ - ("FilePartial1", 
"datachain.lib.signal_schema", "FilePartial1@v1"), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "path": "str", - }, - "hidden_fields": [], - "name": "FilePartial1@v1", - "schema_version": 2, - }, - "FilePartial2@v1": { - "bases": [ - ("FilePartial2", "datachain.lib.signal_schema", "FilePartial2@v1"), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "source": "str", - }, - "hidden_fields": [], - "name": "FilePartial2@v1", - "schema_version": 2, - }, - "CustomPartial1@v1": { - "bases": [ - ( - "CustomPartial1", - "datachain.lib.signal_schema", - "CustomPartial1@v1", - ), - ("DataModel", "datachain.lib.data_model", "DataModel@v1"), - ("BaseModel", "pydantic.main", None), - ("object", "builtins", None), - ], - "fields": { - "file": "FilePartial2@v1", - }, - "hidden_fields": [], - "name": "CustomPartial1@v1", - "schema_version": 2, - }, - }, - "name": "str", - "f": "FilePartial1@v1", - "custom": "CustomPartial1@v1", - } + assert set(partial.values) == {"name", "f", "custom"} + assert partial.values["name"] is str + + f_partial = partial.values["f"] + assert issubclass(f_partial, DataModel) + assert set(f_partial.model_fields) == {"path"} + assert f_partial.model_fields["path"].annotation is str + assert f_partial.__name__.startswith("FilePartial") + + custom_partial = partial.values["custom"] + assert issubclass(custom_partial, DataModel) + assert set(custom_partial.model_fields) == {"file"} + assert custom_partial.__name__.startswith("CustomPartial") + + nested_file_partial = custom_partial.model_fields["file"].annotation + assert issubclass(nested_file_partial, DataModel) + assert nested_file_partial is not f_partial + assert set(nested_file_partial.model_fields) == {"source"} + assert nested_file_partial.model_fields["source"].annotation is str + assert nested_file_partial.__name__.startswith("FilePartial") + + serialized = partial.serialize() + assert serialized["name"] == "str" + assert serialized["f"] == ModelStore.get_name(f_partial) + assert serialized["custom"] == ModelStore.get_name(custom_partial) + assert ModelStore.get_name(nested_file_partial) in serialized["_custom_types"] def test_get_file_signal(): @@ -1638,10 +1594,6 @@ def test_to_partial_complex_signal_entire_file(): # Should return the entire File complex signal assert partial.values == {"file": File} - serialized = partial.serialize() - assert "file" in serialized - assert serialized["file"] == "File@v1" - assert "File@v1" in serialized["_custom_types"] def test_to_partial_complex_nested_signal(): @@ -1652,7 +1604,17 @@ class Custom(DataModel): schema = SignalSchema({"my_col": Custom, "name": str}) partial = schema.to_partial("my_col.src") - assert partial.values == {"my_col.src": File} + assert set(partial.values) == {"my_col"} + + custom_partial = partial.values["my_col"] + assert issubclass(custom_partial, DataModel) + assert set(custom_partial.model_fields) == {"src"} + assert custom_partial.model_fields["src"].annotation is File + assert custom_partial.__name__.startswith("CustomPartial") + + serialized = partial.serialize() + assert serialized["my_col"] == ModelStore.get_name(custom_partial) + assert "_custom_types" in serialized def test_to_partial_complex_deeply_nested_signal(): @@ -1676,9 +1638,20 @@ class Level3(DataModel): # Test deeply nested complex signal partial = 
schema.to_partial("deep.level2.level1.image") - # Should return the entire ImageFile complex signal with simplified name - assert "deep.level2.level1.image" in partial.values - assert partial.values["deep.level2.level1.image"] == ImageFile + deep_partial = partial.values["deep"] + level2_partial = deep_partial.model_fields["level2"].annotation + level1_partial = level2_partial.model_fields["level1"].annotation + + assert issubclass(level1_partial, DataModel) + assert set(level1_partial.model_fields) == {"image"} + assert level1_partial.model_fields["image"].annotation is ImageFile + assert deep_partial.__name__.startswith("Level3Partial") + assert level2_partial.__name__.startswith("Level2Partial") + assert level1_partial.__name__.startswith("Level1Partial") + + serialized = partial.serialize() + assert serialized["deep"] == ModelStore.get_name(deep_partial) + assert ModelStore.get_name(level1_partial) in serialized["_custom_types"] def test_to_partial_complex_nested_multiple_complex_signals(): @@ -1695,11 +1668,16 @@ class Container(DataModel): # Request multiple nested complex signals partial = schema.to_partial("container.file1", "container.file2") - # Should return both complex signals at root level - assert "container.file1" in partial.values - assert "container.file2" in partial.values - assert partial.values["container.file1"] == File - assert partial.values["container.file2"] == TextFile + assert set(partial.values) == {"container"} + + container_partial = partial.values["container"] + assert issubclass(container_partial, DataModel) + assert container_partial.model_fields["file1"].annotation is File + assert container_partial.model_fields["file2"].annotation is TextFile + assert container_partial.__name__.startswith("ContainerPartial") + + serialized = partial.serialize() + assert serialized["container"] == ModelStore.get_name(container_partial) def test_to_partial_complex_nested_mixed_complex_and_simple(): @@ -1715,14 +1693,19 @@ class Container(DataModel): # Request mix of nested complex signal and simple field partial = schema.to_partial("container.file", "container.name", "simple") - # Should have complex signal at root, partial for simple field, and simple type - assert "container.file" in partial.values - assert "container" in partial.values - assert "simple" in partial.values - - assert partial.values["container.file"] == File + assert set(partial.values) == {"container", "simple"} assert partial.values["simple"] is str + container_partial = partial.values["container"] + assert issubclass(container_partial, DataModel) + assert container_partial.model_fields["file"].annotation is File + assert container_partial.model_fields["name"].annotation is str + assert container_partial.__name__.startswith("ContainerPartial") + + serialized = partial.serialize() + assert serialized["container"] == ModelStore.get_name(container_partial) + assert serialized["simple"] == "str" + def test_to_partial_complex_nested_same_type_different_paths(): """Test to_partial with same complex type accessed via different nested paths.""" @@ -1740,12 +1723,22 @@ class Container2(DataModel): # Request same complex type from different nested paths partial = schema.to_partial("cont1.file", "cont2.file") - # Should return single File type at root level (deduplicated) - assert "cont1.file" in partial.values - assert "cont2.file" in partial.values - assert partial.values["cont1.file"] == File - assert partial.values["cont2.file"] == File - assert len(partial.values) == 2 + assert set(partial.values) == {"cont1", 
"cont2"} + + cont1_partial = partial.values["cont1"] + cont2_partial = partial.values["cont2"] + assert issubclass(cont1_partial, DataModel) + assert issubclass(cont2_partial, DataModel) + assert cont1_partial is not cont2_partial + + assert cont1_partial.model_fields["file"].annotation is File + assert cont2_partial.model_fields["file"].annotation is File + assert cont1_partial.__name__.startswith("Container1Partial") + assert cont2_partial.__name__.startswith("Container2Partial") + + serialized = partial.serialize() + assert serialized["cont1"] == ModelStore.get_name(cont1_partial) + assert serialized["cont2"] == ModelStore.get_name(cont2_partial) def test_to_partial_complex_signal_file_single_field(): @@ -1753,15 +1746,16 @@ def test_to_partial_complex_signal_file_single_field(): schema = SignalSchema({"name": str, "file": File}) partial = schema.to_partial("file.path") - serialized = partial.serialize() - assert "name" not in serialized # Only file should be included - assert "file" in serialized - assert serialized["file"] == "FilePartial1@v1" + assert set(partial.values) == {"file"} - # Check the partial type contains only path field - custom_types = serialized["_custom_types"] - file_partial = custom_types["FilePartial1@v1"] - assert file_partial["fields"] == {"path": "str"} + file_partial = partial.values["file"] + assert issubclass(file_partial, DataModel) + assert set(file_partial.model_fields) == {"path"} + assert file_partial.model_fields["path"].annotation is str + assert file_partial.__name__.startswith("FilePartial") + + serialized = partial.serialize() + assert serialized["file"] == ModelStore.get_name(file_partial) def test_to_partial_complex_signal_mixed_entire_and_fields(): @@ -1769,29 +1763,22 @@ def test_to_partial_complex_signal_mixed_entire_and_fields(): schema = SignalSchema({"file1": File, "file2": File, "name": str}) partial = schema.to_partial("file1", "file2.path", "name") - serialized = partial.serialize() - assert "file1" in serialized - assert "file2" in serialized - assert "name" in serialized + assert set(partial.values) == {"file1", "file2", "name"} - # file1 should be the entire File type - assert serialized["file1"] == "File@v1" - # file2 should be a partial with only path field - assert serialized["file2"].startswith("FilePartial") and serialized[ - "file2" - ].endswith("@v1") - # name should be simple type - assert serialized["name"] == "str" + assert partial.values["file1"] is File + assert partial.values["name"] is str - # Check custom types - custom_types = serialized["_custom_types"] - assert "File@v1" in custom_types - file2_partial_key = serialized["file2"] - assert file2_partial_key in custom_types + file2_partial = partial.values["file2"] + assert issubclass(file2_partial, DataModel) + assert set(file2_partial.model_fields) == {"path"} + assert file2_partial.model_fields["path"].annotation is str + assert file2_partial.__name__.startswith("FilePartial") - # Check partial has only path field - file2_partial = custom_types[file2_partial_key] - assert file2_partial["fields"] == {"path": "str"} + serialized = partial.serialize() + assert serialized["file1"] == "File@v1" + assert serialized["file2"] == ModelStore.get_name(file2_partial) + assert serialized["name"] == "str" + assert ModelStore.get_name(file2_partial) in serialized["_custom_types"] def test_to_partial_complex_signal_multiple_entire_files(): @@ -1799,24 +1786,13 @@ def test_to_partial_complex_signal_multiple_entire_files(): schema = SignalSchema({"file1": File, "file2": File, "name": 
str})
     partial = schema.to_partial("file1", "file2")
 
-    serialized = partial.serialize()
-    assert "file1" in serialized
-    assert "file2" in serialized
-    assert "name" not in serialized  # name was not requested
-
-    # Both should be the entire File type
-    assert serialized["file1"] == "File@v1"
-    assert serialized["file2"] == "File@v1"
-
-    # Should have the full File custom type
-    custom_types = serialized["_custom_types"]
-    assert "File@v1" in custom_types
-    assert len(custom_types) == 1  # Only one File type needed
+    assert set(partial.values) == {"file1", "file2"}
+    assert partial.values["file1"] is File
+    assert partial.values["file2"] is File
 
 
 def test_to_partial_complex_signal_nested_entire():
     """Test to_partial with nested complex signal - entire parent."""
-    from datachain.lib.data_model import DataModel
 
     class Container(DataModel):
         name: str
@@ -1825,21 +1801,13 @@ class Container(DataModel):
     schema = SignalSchema({"container": Container, "simple": str})
     partial = schema.to_partial("container")
 
-    serialized = partial.serialize()
-    assert "container" in serialized
-    assert "simple" not in serialized
-
-    # Should be the entire Container type
-    assert serialized["container"].startswith("Container@")
+    assert set(partial.values) == {"container"}
 
-    # Should have the full Container custom type with nested File
-    custom_types = serialized["_custom_types"]
-    container_key = serialized["container"]
-    assert container_key in custom_types
-    assert "File@v1" in custom_types
-
-    container_type = custom_types[container_key]
-    assert container_type["fields"] == {"name": "str", "file": "File@v1"}
+    container_type = partial.values["container"]
+    assert issubclass(container_type, DataModel)
+    assert set(container_type.model_fields) == {"name", "file"}
+    assert container_type.model_fields["name"].annotation is str
+    assert container_type.model_fields["file"].annotation is File
 
 
 def test_to_partial_complex_signal_empty_request():
@@ -1849,8 +1817,6 @@ def test_to_partial_complex_signal_empty_request():
 
     # Should return empty schema
     assert partial.values == {}
-    serialized = partial.serialize()
-    assert "_custom_types" not in serialized or not serialized.get("_custom_types")
 
 
 def test_to_partial_complex_signal_error_invalid_signal():

From 8f78f83c50f75118cad1bfed42a8a6afd202bc4c Mon Sep 17 00:00:00 2001
From: Ivan Shcheklein
Date: Tue, 9 Dec 2025 13:52:35 -0800
Subject: [PATCH 2/2] use fingerprints to reuse and disambiguate partials

---
 src/datachain/lib/data_model.py           |  55 ++-
 src/datachain/lib/model_store.py          |  39 +-
 src/datachain/lib/signal_schema.py        | 430 ++++++++---------
 src/datachain/lib/utils.py                | 127 ++++-
 tests/func/test_datachain.py              | 125 ++---
 tests/func/test_signal_schema.py          |  76 +++
 tests/unit/lib/test_signal_schema.py      | 462 +------------------
 tests/unit/lib/test_utils.py              | 130 +++++-
 tests/unit/test_data_model.py             |  83 ++++
 tests/unit/test_signal_schema_partials.py | 534 ++++++++++++++++++++++
 10 files changed, 1337 insertions(+), 724 deletions(-)
 create mode 100644 tests/func/test_signal_schema.py
 create mode 100644 tests/unit/test_data_model.py
 create mode 100644 tests/unit/test_signal_schema_partials.py

diff --git a/src/datachain/lib/data_model.py b/src/datachain/lib/data_model.py
index f5fd4d6d7..98118b218 100644
--- a/src/datachain/lib/data_model.py
+++ b/src/datachain/lib/data_model.py
@@ -1,3 +1,4 @@
+import hashlib
 import inspect
 import types
 import uuid
@@ -8,8 +9,9 @@
 from pydantic import AliasChoices, BaseModel, Field, create_model
 from pydantic.fields import FieldInfo
 
+from datachain 
import json from datachain.lib.model_store import ModelStore -from datachain.lib.utils import normalize_col_names +from datachain.lib.utils import normalize_col_names, type_to_str StandardType = ( type[int] @@ -52,6 +54,57 @@ def hidden_fields(cls) -> list[str]: return cls._hidden_fields +def compute_model_fingerprint( + model: type[BaseModel], selection: dict[str, "dict[str, object] | None"] +) -> str: + """ + Compute a deterministic fingerprint for a model given a selection subtree. + + Selection uses the same structure as SignalSchema.to_partial: a mapping from + field name -> nested selection dict or None (leaf). + """ + + def _fingerprint_tree( + model_type: type[BaseModel], sel: dict[str, "dict[str, object] | None"] + ) -> dict[str, object]: + tree: dict[str, object] = {} + for field_name, sub_sel in sorted(sel.items()): + if field_name not in model_type.model_fields: + raise ValueError( + f"Field {field_name} not found in {model_type.__name__}" + ) + + finfo = model_type.model_fields[field_name] + field_type = finfo.annotation + required = finfo.is_required() + entry: dict[str, object] = { + "type": type_to_str(field_type, register_pydantic=False), + "required": bool(required), + "default": None if required else repr(finfo.default), + } + + child_model = ModelStore.to_pydantic(field_type) + if sub_sel is not None: + if child_model is None: + raise ValueError( + f"Field {field_name} in {model_type.__name__} is not a model" + ) + entry["children"] = _fingerprint_tree( + child_model, + sub_sel, # type: ignore[arg-type] + ) + tree[field_name] = entry + + return tree + + payload = { + "model": ModelStore.get_name(model), + "selection": _fingerprint_tree(model, selection), + } + json_str = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + + def is_chain_type(t: type) -> bool: """Return true if type is supported by `DataChain`.""" if ModelStore.is_pydantic(t): diff --git a/src/datachain/lib/model_store.py b/src/datachain/lib/model_store.py index 060b22f11..5abe8f652 100644 --- a/src/datachain/lib/model_store.py +++ b/src/datachain/lib/model_store.py @@ -7,6 +7,17 @@ class ModelStore: store: ClassVar[dict[str, dict[int, type[BaseModel]]]] = {} + @staticmethod + def _base_name(model: type[BaseModel]) -> str: + # Some models are generated/restored with a versioned Python class name + # (e.g. `MyType_v1`) so that multiple versions can coexist in-process without + # `__name__` collisions. + # + # `_modelstore_base_name` preserves the original/logical name (e.g. `MyType`) + # after we alter the class name, so `ModelStore.get_name(model)` still + # returns `"MyType@v{model._version}"` and schema serialization stays stable. 
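+        #
+        # Illustrative example (hypothetical names): a partial generated by
+        # SignalSchema.to_partial may get __name__ "FilePartial_ab12cd34ef_v1"
+        # while _modelstore_base_name stays "FilePartial_ab12cd34ef", so
+        # get_name() returns "FilePartial_ab12cd34ef@v1".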
+ return getattr(model, "_modelstore_base_name", model.__name__) + @classmethod def get_version(cls, model: type[BaseModel]) -> int: if not hasattr(model, "_version"): @@ -15,9 +26,10 @@ def get_version(cls, model: type[BaseModel]) -> int: @classmethod def get_name(cls, model) -> str: + base_name = cls._base_name(model) if (version := cls.get_version(model)) > 0: - return f"{model.__name__}@v{version}" - return model.__name__ + return f"{base_name}@v{version}" + return base_name @classmethod def register(cls, fr: type): @@ -25,11 +37,18 @@ def register(cls, fr: type): if (model := ModelStore.to_pydantic(fr)) is None: return - name = model.__name__ - if name not in cls.store: - cls.store[name] = {} + base_name = cls._base_name(model) + unique_name = model.__name__ version = ModelStore.get_version(model) - cls.store[name][version] = model + + # Register under both: + # - `base_name` (logical/original name, from `_modelstore_base_name` when set) + # so callers can resolve by original name + version, e.g. + # `ModelStore.get("Foo", 1)` after deserialization/restore. + # - `unique_name` (Python class `__name__`, e.g. `Foo_v1`) so callers can also + # resolve by the runtime class name. + for name in (base_name, unique_name): + cls.store.setdefault(name, {})[version] = model for f_info in model.model_fields.values(): if (anno := ModelStore.to_pydantic(f_info.annotation)) is not None: @@ -62,8 +81,12 @@ def parse_name_version(cls, fullname: str) -> tuple[str, int]: @classmethod def remove(cls, fr: type) -> None: version = fr._version # type: ignore[attr-defined] - if fr.__name__ in cls.store and version in cls.store[fr.__name__]: - del cls.store[fr.__name__][version] + base_name = cls._base_name(fr) + unique_name = fr.__name__ + + for name in (base_name, unique_name): + if name in cls.store and version in cls.store[name]: + del cls.store[name][version] @staticmethod def is_pydantic(val: Any) -> bool: diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index c83a4ecf7..14572f533 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -11,10 +11,8 @@ from typing import ( IO, TYPE_CHECKING, - Annotated, Any, Final, - Literal, Optional, Union, get_args, @@ -23,7 +21,6 @@ from pydantic import BaseModel, Field, ValidationError, create_model from sqlalchemy import ColumnElement -from typing_extensions import Literal as LiteralEx from datachain import json from datachain.func import literal @@ -31,10 +28,19 @@ from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.convert.sql_to_python import sql_to_python from datachain.lib.convert.unflatten import unflatten_to_json_pos -from datachain.lib.data_model import DataModel, DataType, DataValue +from datachain.lib.data_model import ( + DataModel, + DataType, + DataValue, + compute_model_fingerprint, +) from datachain.lib.file import File from datachain.lib.model_store import ModelStore -from datachain.lib.utils import DataChainColumnError, DataChainParamsError +from datachain.lib.utils import ( + DataChainColumnError, + DataChainParamsError, + type_to_str, +) from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta from datachain.sql.types import SQLType @@ -154,6 +160,7 @@ class CustomType(BaseModel): fields: dict[str, str] bases: list[tuple[str, str, str | None]] hidden_fields: list[str] | None = None + partial_fingerprint: str | None = None @classmethod def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType": @@ -166,6 +173,7 
@@ def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType": "fields": data, "bases": [], "hidden_fields": [], + "partial_fingerprint": None, } return cls(**data) @@ -173,19 +181,39 @@ def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType": def create_feature_model( name: str, - fields: Mapping[str, type | tuple[type, Any] | None], + fields: Mapping[str, Any], base: type | None = None, + *, + partial_fingerprint: str | None = None, + hidden_fields: list[str] | None = None, ) -> type[BaseModel]: """ - This gets or returns a dynamic feature model for use in restoring a model - from the custom_types stored within a serialized SignalSchema. This is useful - when using a custom feature model where the original definition is not available. - This happens in Studio and if a custom model is used in a dataset, then that dataset - is used in a DataChain in a separate script where that model is not declared. + Build and register a dynamic feature model so it can be resolved later by name. + + Used when the original definition is not available (e.g., Studio restores or + cross-process dataset loads) and when deriving partial models in + ``SignalSchema.to_partial``. + + Args: + name: Logical model name. If it includes a version suffix like ``@v1``, the + version is parsed into ``_version``. + fields: Mapping of field definitions for the model body. + base: Base class for the generated model (defaults to ``DataModel``). + partial_fingerprint: If set, store ``_partial_fingerprint`` metadata. + hidden_fields: If set, store ``_hidden_fields`` metadata. + + Notes: + - The generated Python class name is versioned (e.g. ``MyType_v1``) to avoid + collisions when multiple versions are loaded in one process. + - ``_modelstore_base_name`` preserves the original/logical name (e.g. + ``MyType``), and ``ModelStore.register()`` stores the model under both the + logical name and the runtime class name for robust lookups. 
""" - name = name.replace("@", "_") - return create_model( - name, + base_name, parsed_version = ModelStore.parse_name_version(name) + class_name = f"{base_name}_v{parsed_version}" if parsed_version > 0 else base_name + model_name = class_name.replace("@", "_") + model = create_model( + model_name, __base__=base or DataModel, # type: ignore[call-overload] # These are tuples for each field of: annotation, default (if any) **{ @@ -194,6 +222,17 @@ def create_feature_model( }, # type: ignore[arg-type] ) + model._version = parsed_version # type: ignore[attr-defined] + model._modelstore_base_name = base_name # type: ignore[attr-defined] + if partial_fingerprint is not None: + model._partial_fingerprint = partial_fingerprint # type: ignore[attr-defined] + if hidden_fields is not None: + model._hidden_fields = hidden_fields # type: ignore[attr-defined] + + ModelStore.register(model) + + return model + @dataclass class SignalSchema: @@ -250,7 +289,8 @@ def _get_bases(fr: type) -> list[tuple[str, str, str | None]]: model_store_name = ( ModelStore.get_name(base) if issubclass(base, DataModel) else None ) - bases.append((base.__name__, base.__module__, model_store_name)) + base_name = getattr(base, "_modelstore_base_name", base.__name__) + bases.append((base_name, base.__module__, model_store_name)) return bases @staticmethod @@ -278,8 +318,9 @@ def _serialize_custom_model( fields=fields, bases=bases, hidden_fields=getattr(fr, "_hidden_fields", []), + partial_fingerprint=getattr(fr, "_partial_fingerprint", None), ) - custom_types[version_name] = ct.model_dump() + custom_types[version_name] = ct.model_dump(exclude_none=True) return version_name @@ -346,13 +387,18 @@ def _deserialize_custom_type( """Given a type name like MyType@v1 gets a type from ModelStore or recreates it based on the information from the custom types dict that includes fields and bases.""" - model_name, version = ModelStore.parse_name_version(type_name) - fr = ModelStore.get(model_name, version) - if fr: - return fr + model_name, target_version = ModelStore.parse_name_version(type_name) if type_name in custom_types: - ct = CustomType.deserialize(custom_types[type_name], type_name) + try: + ct = CustomType.deserialize(custom_types[type_name], type_name) + except ValidationError as exc: + raise SignalSchemaError( + f"cannot deserialize custom type '{type_name}': {exc}" + ) from exc + + if fr := ModelStore.get(model_name, target_version): + return fr fields = { field_name: SignalSchema._resolve_type(field_type_str, custom_types) @@ -363,16 +409,22 @@ def _deserialize_custom_type( for base in ct.bases: _, _, model_store_name = base if model_store_name: - model_name, version = ModelStore.parse_name_version( + base_model_name, base_version = ModelStore.parse_name_version( model_store_name ) - base_model = ModelStore.get(model_name, version) + base_model = ModelStore.get(base_model_name, base_version) if base_model: break - return create_feature_model(type_name, fields, base=base_model) + return create_feature_model( + type_name, + fields, + base=base_model, + hidden_fields=ct.hidden_fields, + partial_fingerprint=ct.partial_fingerprint, + ) - return None + return ModelStore.get(model_name, target_version) @staticmethod def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None: # noqa: PLR0911 @@ -1075,82 +1127,20 @@ def __contains__(self, name: str): return name in self.values @staticmethod - def _type_to_str( # noqa: C901, PLR0911, PLR0912 + def _type_to_str( type_: type | None | types.EllipsisType, subtypes: list | 
None = None ) -> str: """Convert a type to a string-based representation.""" - if type_ is None: - return "NoneType" - if type_ is Ellipsis: - return "..." - origin = get_origin(type_) + def _warn(msg: str) -> None: + warnings.warn(msg, SignalSchemaWarning, stacklevel=2) - if origin in (Union, types.UnionType): - args = get_args(type_) - if len(args) == 2 and type(None) in args: - # This is an Optional type. - non_none_type = args[0] if args[1] is type(None) else args[1] - type_str = SignalSchema._type_to_str(non_none_type, subtypes) - return f"Optional[{type_str}]" - formatted_types = ", ".join( - SignalSchema._type_to_str(arg, subtypes) for arg in args - ) - return f"Union[{formatted_types}]" - if origin == Optional: - args = get_args(type_) - type_str = SignalSchema._type_to_str(args[0], subtypes) - return f"Optional[{type_str}]" - if origin is list: - args = get_args(type_) - if len(args) == 0: - return "list" - type_str = SignalSchema._type_to_str(args[0], subtypes) - return f"list[{type_str}]" - if origin is tuple: - args = get_args(type_) - if len(args) == 0: - return "tuple" - if len(args) == 2 and args[1] is Ellipsis: - inner = SignalSchema._type_to_str(args[0], subtypes) - return f"tuple[{inner}, ...]" - type_str = ", ".join( - SignalSchema._type_to_str(arg, subtypes) for arg in args - ) - return f"tuple[{type_str}]" - if origin is dict: - args = get_args(type_) - if len(args) == 0: - return "dict" - key_type = SignalSchema._type_to_str(args[0], subtypes) - if len(args) == 1: - return f"dict[{key_type}, Any]" - val_type = SignalSchema._type_to_str(args[1], subtypes) - return f"dict[{key_type}, {val_type}]" - if origin == Annotated: - args = get_args(type_) - return SignalSchema._type_to_str(args[0], subtypes) - if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx): - return "Literal" - if Any in (origin, type_): - return "Any" - if Final in (origin, type_): - return "Final" - if subtypes is not None: - # Include this type in the list of all subtypes, if requested. - subtypes.append(type_) - if not hasattr(type_, "__name__"): - # This can happen for some third-party or custom types - warnings.warn( - f"Unable to determine name of type '{type_}'.", - SignalSchemaWarning, - stacklevel=2, - ) - return "Any" - if ModelStore.is_pydantic(type_): - ModelStore.register(type_) - return ModelStore.get_name(type_) - return type_.__name__ + return type_to_str( + type_, + subtypes, + warn_with=_warn, + register_pydantic=True, + ) @staticmethod def _build_tree_for_type( @@ -1176,147 +1166,171 @@ def _build_tree_for_model( return res - def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901 - """ - Convert the schema to a partial schema with only the specified columns. + def to_partial(self, *columns: str) -> "SignalSchema": # noqa: C901, PLR0915 + """Return a schema that contains only the requested signals. - E.g. if original schema is: + Selection syntax uses dot-separated paths for nested fields: - ``` - signal: Foo@v1 - name: str - value: float - count: int - ``` + - Top-level fields: ``"name"`` + - Nested fields: ``"person.age"`` - Then `to_partial("signal.name", "count")` will return a partial schema: + Selection merge rules: - ``` - signal: FooPartial@v1 + - If a parent is selected (e.g. ``"person"``), it wins over any nested + selections (``"person.age"`` is redundant). + - If only some nested fields are selected (e.g. ``"person.age"``), a + partial model is generated for that nested model. 
+ - If the nested selection ends up including *all* fields of a model, the + original model type is reused (no new partial model is created). + + Example: + + class Person(DataModel): name: str - count: int - ``` + age: int + + schema = SignalSchema({"id": int, "person": Person}) + partial = schema.to_partial("id", "person.age") - Note that partial schema will have a different name for the custom types - (e.g. `FooPartial@v1` instead of `Foo@v1`) to avoid conflicts - with the original schema. + person_type = ModelStore.to_pydantic(partial.values["person"]) + assert person_type is not None + assert set(person_type.model_fields) == {"age"} Args: - *columns (str): The columns to include in the partial schema. + *columns: Signal names to include. Returns: - SignalSchema: The new partial schema. + A new ``SignalSchema`` restricted to the requested signals. """ - serialized = self.serialize() - custom_types = serialized.get("_custom_types", {}) + if not columns: + return SignalSchema({}) - schema: dict[str, Any] = {} - schema_custom_types: dict[str, CustomType] = {} + selections: dict[str, dict[str, Any] | None] = {} - data_model_bases: list[tuple[str, str, str | None]] | None = None + def _validate_and_split_path(column: str) -> list[str]: + parts = column.split(".") + if parts[0] not in self.tree: + raise SignalSchemaError(f"Column {column} not found in the schema") - signal_partials: dict[str, str] = {} - partial_versions: dict[str, int] = {} + curr_type, curr_tree = self.tree[parts[0]] - def _ensure_partial_custom_type(type_name: str) -> None: - nonlocal data_model_bases + for part in parts[1:]: + if curr_tree is None: + raise SignalSchemaError(f"Column {column} not found in the schema") - if ( - type_name in custom_types - or type_name in schema_custom_types - or type_name in NAMES_TO_TYPES - or "Partial" not in type_name - ): - return + node = curr_tree.get(part) + if node is None: + parent_model = ModelStore.to_pydantic(curr_type) + if parent_model is not None: + raise SignalSchemaError( + f"Field {part} not found in custom type " + f"{parent_model.__name__}" + ) + raise SignalSchemaError(f"Column {column} not found in the schema") - if data_model_bases is None: - data_model_bases = SignalSchema._get_bases(DataModel) + curr_type, curr_tree = node - parsed_name, _ = ModelStore.parse_name_version(type_name) - schema_custom_types[type_name] = CustomType( - schema_version=2, - name=parsed_name, - fields={}, - bases=[(parsed_name, "__main__", type_name), *data_model_bases], - ) + return parts - def _type_name_to_partial(signal_name: str, type_name: str) -> str: - # Check if we need to create a partial for this type - # Only create partials for custom types that are in the custom_types dict - if type_name not in custom_types: - return type_name + def _merge_selection(parts: list[str]) -> None: + curr: dict[str, dict[str, Any] | None] = selections + missing = object() + for idx, part in enumerate(parts): + is_last = idx == len(parts) - 1 + existing = curr.get(part, missing) - if "@" in type_name: - model_name, _ = ModelStore.parse_name_version(type_name) - else: - model_name = type_name + if existing is None: + return - if signal_name not in signal_partials: - partial_versions.setdefault(model_name, 0) - partial_versions[model_name] += 1 - version = partial_versions[model_name] - signal_partials[signal_name] = f"{model_name}Partial{version}" + if is_last: + curr[part] = None + return - return signal_partials[signal_name] + if existing is missing: + next_sel: dict[str, Any] = {} + curr[part] = 
next_sel + curr = next_sel + else: + curr = existing # type: ignore[assignment] for column in columns: - parent_type, parent_type_partial = "", "" - column_parts = column.split(".") - for i, signal in enumerate(column_parts): - if i == 0: - if signal not in serialized: - raise SignalSchemaError( - f"Column {column} not found in the schema" - ) - - parent_type = serialized[signal] - parent_type_partial = _type_name_to_partial(signal, parent_type) - - schema[signal] = parent_type_partial - - # If this is a complex signal without field specifier (just "file") - # and it's a custom type, include the entire complex signal - if len(column_parts) == 1 and parent_type in custom_types: - # Include the entire complex signal - no need to create partial - schema[signal] = parent_type - continue - - continue + if not isinstance(column, str): + raise SignalResolvingTypeError("to_partial()", column) + + column_parts = _validate_and_split_path(column) + _merge_selection(column_parts) + + def _build_partial_type( + base_type: Any, selection: dict[str, Any] | None, path: list[str] + ) -> Any: + if selection is None: + return base_type + + if not selection: # pragma: no cover + raise RuntimeError( + "Internal error in SignalSchema.to_partial(): " + f"empty selection for '{'.'.join(path)}'" + ) - if parent_type not in custom_types: - raise SignalSchemaError( - f"Custom type {parent_type} not found in the schema" - ) + model = ModelStore.to_pydantic(base_type) + assert model is not None, "Expected complex type to be a Pydantic model" - custom_type = custom_types[parent_type] - signal_type = custom_type["fields"].get(signal) - if not signal_type: - raise SignalSchemaError( - f"Field {signal} not found in custom type {parent_type}" - ) - - is_leaf = i == len(column_parts) - 1 + if set(selection.keys()) == set(model.model_fields.keys()) and all( + sub_selection is None for sub_selection in selection.values() + ): + return base_type - if is_leaf and signal_type in custom_types: - # Selecting an entire nested complex field: keep the original type - partial_type = signal_type + field_types: dict[str, Any] = {} + for field_name, sub_selection in selection.items(): + assert field_name in model.model_fields, ( + "Selection should match existing model fields" + ) + field_info = model.model_fields[field_name] + field_type = field_info.annotation + assert field_type is not None, "Model fields must be typed" + partial_type = _build_partial_type( + field_type, sub_selection, [*path, field_name] + ) + if field_info.is_required(): + field_types[field_name] = partial_type else: - partial_type = _type_name_to_partial( - ".".join(column_parts[: i + 1]), - signal_type, - ) - - _ensure_partial_custom_type(partial_type) - _ensure_partial_custom_type(parent_type_partial) + field_types[field_name] = (partial_type, field_info.default) - schema_custom_types[parent_type_partial].fields[signal] = partial_type + assert field_types, ( + f"Empty field set when building partial for {model.__name__}" + ) - parent_type, parent_type_partial = signal_type, partial_type + fingerprint = compute_model_fingerprint(model, selection) + base_name, _ = ModelStore.parse_name_version(ModelStore.get_name(model)) + base_partial_name = f"{base_name}Partial_{fingerprint[:10]}" + base_hidden_fields = getattr(model, "_hidden_fields", []) + + version = 1 + + existing = ModelStore.get(base_partial_name, version) + if existing is None: + partial_model = create_feature_model( + f"{base_partial_name}@v{version}", + field_types, + base=DataModel, + 
partial_fingerprint=fingerprint, + hidden_fields=[ + fname for fname in base_hidden_fields if fname in field_types + ], + ) + elif getattr(existing, "_partial_fingerprint", None) == fingerprint: + partial_model = existing # type: ignore[assignment] + else: + msg = ( + f"partial model name collision '{base_partial_name}@v{version}' " + "with a different fingerprint" + ) + raise SignalSchemaError(msg) + return partial_model - if schema_custom_types: - schema["_custom_types"] = { - type_name: ct.model_dump() - for type_name, ct in schema_custom_types.items() - } + new_values: dict[str, DataType] = {} + for signal, selection in selections.items(): + base_type = self.values[signal] + new_values[signal] = _build_partial_type(base_type, selection, [signal]) - return SignalSchema.deserialize(schema) + return SignalSchema(new_values) diff --git a/src/datachain/lib/utils.py b/src/datachain/lib/utils.py index 4cff2e48f..07701bc9e 100644 --- a/src/datachain/lib/utils.py +++ b/src/datachain/lib/utils.py @@ -1,10 +1,16 @@ import inspect import re +import types +import warnings from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from pathlib import PurePosixPath +from typing import Annotated, Any, Final, Literal, Union, get_args, get_origin +from typing import Literal as LiteralEx from urllib.parse import urlparse +from datachain.lib.model_store import ModelStore + class AbstractUDF(ABC): @abstractmethod @@ -173,3 +179,122 @@ def rebase_path( return f"{new_base_parsed.scheme}://{full_path}" # Regular path return str(PurePosixPath(new_base) / new_relative_path) + + +def type_to_str( # noqa: C901, PLR0911, PLR0912 + type_: type | None | types.EllipsisType, + subtypes: list | None = None, + *, + warn_with: Callable[[str], None] | None = None, + register_pydantic: bool = False, +) -> str: + """Convert a type to a string representation shared across schema code.""" + + if type_ is None: + return "NoneType" + if type_ is Ellipsis: + return "..." 
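+    # The two early returns above handle sentinels: a `None` annotation is
+    # serialized as "NoneType", and `Ellipsis` (e.g. from `tuple[int, ...]`)
+    # as "...".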
+ + origin = get_origin(type_) + + if origin in (Union, types.UnionType): + args = get_args(type_) + if len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + type_str = type_to_str( + non_none_type, + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + return f"Optional[{type_str}]" + formatted_types = ", ".join( + type_to_str( + arg, + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + for arg in args + ) + return f"Union[{formatted_types}]" + if origin is list: + args = get_args(type_) + if len(args) == 0: + return "list" + type_str = type_to_str( + args[0], + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + return f"list[{type_str}]" + if origin is dict: + args = get_args(type_) + if len(args) == 0: + return "dict" + key_type = type_to_str( + args[0], + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + if len(args) == 1: + return f"dict[{key_type}, Any]" + val_type = type_to_str( + args[1], + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + return f"dict[{key_type}, {val_type}]" + if origin is tuple: + args = get_args(type_) + if len(args) == 0: + return "tuple" + if len(args) == 2 and args[1] is Ellipsis: + inner = type_to_str( + args[0], + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + return f"tuple[{inner}, ...]" + formatted_types = ", ".join( + type_to_str( + arg, + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + for arg in args + ) + return f"tuple[{formatted_types}]" + if origin is Annotated: + args = get_args(type_) + return type_to_str( + args[0], + subtypes, + warn_with=warn_with, + register_pydantic=register_pydantic, + ) + if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx): + return "Literal" + if Any in (origin, type_): + return "Any" + if Final in (origin, type_): + return "Final" + if subtypes is not None: + subtypes.append(type_) + if not hasattr(type_, "__name__"): + msg = f"Unable to determine name of type '{type_}'." 
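+        # `warn_with` is a caller-supplied hook so this shared helper does not
+        # have to import schema-specific warning classes (SignalSchema passes
+        # one that emits SignalSchemaWarning); otherwise fall back to a
+        # generic RuntimeWarning.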
+ if warn_with is not None: + warn_with(msg) + else: + warnings.warn(msg, RuntimeWarning, stacklevel=2) + return "Any" + if ModelStore.is_pydantic(type_): + if register_pydantic: + ModelStore.register(type_) + return ModelStore.get_name(type_) + return type_.__name__ diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 0206c139c..381e4569b 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -19,6 +19,7 @@ import datachain as dc from datachain import DataModel, func from datachain.dataset import DatasetDependencyType +from datachain.lib.data_model import compute_model_fingerprint from datachain.lib.file import File, ImageFile from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri from datachain.lib.tar import process_tar @@ -1081,14 +1082,18 @@ def file_info(file: File) -> FileInfo: .save("my-ds") ) + fp_file_info_path = compute_model_fingerprint(FileInfo, {"path": None}) + fi_partial_base = f"FileInfoPartial_{fp_file_info_path[:10]}" + fi_partial_name = f"{fi_partial_base}@v1" + assert ds.signals_schema.serialize() == { "_custom_types": { - "FileInfoPartial1@v1": { + fi_partial_name: { "bases": [ ( - "FileInfoPartial1", + fi_partial_base, "datachain.lib.signal_schema", - "FileInfoPartial1@v1", + fi_partial_name, ), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), @@ -1096,11 +1101,12 @@ def file_info(file: File) -> FileInfo: ], "fields": {"path": "str"}, "hidden_fields": [], - "name": "FileInfoPartial1@v1", + "name": fi_partial_name, "schema_version": 2, + "partial_fingerprint": fp_file_info_path, } }, - "file_info": "FileInfoPartial1@v1", + "file_info": fi_partial_name, "cnt": "int", "sum": "int", "value": "int", @@ -1146,43 +1152,44 @@ def file_info(file: File) -> FileInfo: .save("my-ds") ) + fp_file_info_name = compute_model_fingerprint(FileInfo, {"name": None}) + fp_file_info_path = compute_model_fingerprint(FileInfo, {"path": None}) + fi_name_base = f"FileInfoPartial_{fp_file_info_name[:10]}" + fi_name_full = f"{fi_name_base}@v1" + fi_path_base = f"FileInfoPartial_{fp_file_info_path[:10]}" + fi_path_full = f"{fi_path_base}@v1" + assert ds.signals_schema.serialize() == { "_custom_types": { - "FileInfoPartial1@v1": { + fi_name_full: { "bases": [ - ( - "FileInfoPartial1", - "datachain.lib.signal_schema", - "FileInfoPartial1@v1", - ), + (fi_name_base, "datachain.lib.signal_schema", fi_name_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", "builtins", None), ], "fields": {"name": "str"}, "hidden_fields": [], - "name": "FileInfoPartial1@v1", + "name": fi_name_full, "schema_version": 2, + "partial_fingerprint": fp_file_info_name, }, - "FileInfoPartial2@v1": { + fi_path_full: { "bases": [ - ( - "FileInfoPartial2", - "datachain.lib.signal_schema", - "FileInfoPartial2@v1", - ), + (fi_path_base, "datachain.lib.signal_schema", fi_path_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", "builtins", None), ], "fields": {"path": "str"}, "hidden_fields": [], - "name": "FileInfoPartial2@v1", + "name": fi_path_full, "schema_version": 2, + "partial_fingerprint": fp_file_info_path, }, }, - "f1": "FileInfoPartial1@v1", - "f2": "FileInfoPartial2@v1", + "f1": fi_name_full, + "f2": fi_path_full, "cnt": "int", "sum": "int", } @@ -1209,6 +1216,7 @@ def test_group_by_signals_nested(cloud_test_catalog): class FileName(DataModel): name: str = "" + ext: str = 
"" class FileInfo(DataModel): path: str = "" @@ -1218,10 +1226,12 @@ def file_info(file: File) -> FileInfo: full_path = file.source.rstrip("/") + "/" + file.path rel_path = posixpath.relpath(full_path, src_uri) path_parts = rel_path.split("/", 1) + file_name = path_parts[1] if len(path_parts) > 1 else path_parts[0] return FileInfo( path=path_parts[0] if len(path_parts) > 1 else "", name=FileName( - name=path_parts[1] if len(path_parts) > 1 else path_parts[0], + name=file_name, + ext=posixpath.splitext(file_name)[1].lstrip("."), ), ) @@ -1237,59 +1247,61 @@ def file_info(file: File) -> FileInfo: .save("my-ds") ) + fp_file_info_name = compute_model_fingerprint(FileInfo, {"name": {"name": None}}) + fp_file_info_path = compute_model_fingerprint(FileInfo, {"path": None}) + fp_file_name = compute_model_fingerprint(FileName, {"name": None}) + + fi_name_base = f"FileInfoPartial_{fp_file_info_name[:10]}" + fi_name_full = f"{fi_name_base}@v1" + fi_path_base = f"FileInfoPartial_{fp_file_info_path[:10]}" + fi_path_full = f"{fi_path_base}@v1" + fname_base = f"FileNamePartial_{fp_file_name[:10]}" + fname_full = f"{fname_base}@v1" + assert ds.signals_schema.serialize() == { "_custom_types": { - "FileInfoPartial1@v1": { + fname_full: { "bases": [ - ( - "FileInfoPartial1", - "datachain.lib.signal_schema", - "FileInfoPartial1@v1", - ), + (fname_base, "datachain.lib.signal_schema", fname_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", "builtins", None), ], - "fields": {"name": "FileNamePartial1@v1"}, + "fields": {"name": "str"}, "hidden_fields": [], - "name": "FileInfoPartial1@v1", + "name": fname_full, "schema_version": 2, + "partial_fingerprint": fp_file_name, }, - "FileInfoPartial2@v1": { + fi_name_full: { "bases": [ - ( - "FileInfoPartial2", - "datachain.lib.signal_schema", - "FileInfoPartial2@v1", - ), + (fi_name_base, "datachain.lib.signal_schema", fi_name_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", "builtins", None), ], - "fields": {"path": "str"}, + "fields": {"name": fname_full}, "hidden_fields": [], - "name": "FileInfoPartial2@v1", + "name": fi_name_full, "schema_version": 2, + "partial_fingerprint": fp_file_info_name, }, - "FileNamePartial1@v1": { + fi_path_full: { "bases": [ - ( - "FileNamePartial1", - "datachain.lib.signal_schema", - "FileNamePartial1@v1", - ), + (fi_path_base, "datachain.lib.signal_schema", fi_path_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", "builtins", None), ], - "fields": {"name": "str"}, + "fields": {"path": "str"}, "hidden_fields": [], - "name": "FileNamePartial1@v1", + "name": fi_path_full, "schema_version": 2, + "partial_fingerprint": fp_file_info_path, }, }, - "f1": "FileInfoPartial1@v1", - "f2": "FileInfoPartial2@v1", + "f1": fi_name_full, + "f2": fi_path_full, "cnt": "int", "sum": "int", } @@ -1329,26 +1341,27 @@ def process(file: File) -> BBox: .save("my-ds") ) + fp_bbox_title = compute_model_fingerprint(BBox, {"title": None}) + bbox_base = f"BBoxPartial_{fp_bbox_title[:10]}" + bbox_full = f"{bbox_base}@v1" + assert ds.signals_schema.serialize() == { "_custom_types": { - "BBoxPartial1@v1": { + bbox_full: { "bases": [ - ( - "BBoxPartial1", - "datachain.lib.signal_schema", - "BBoxPartial1@v1", - ), + (bbox_base, "datachain.lib.signal_schema", bbox_full), ("DataModel", "datachain.lib.data_model", "DataModel@v1"), ("BaseModel", "pydantic.main", None), ("object", 
"builtins", None), ], "fields": {"title": "str"}, "hidden_fields": [], - "name": "BBoxPartial1@v1", + "name": bbox_full, "schema_version": 2, + "partial_fingerprint": fp_bbox_title, } }, - "box": "BBoxPartial1@v1", + "box": bbox_full, "cnt": "int", "value": "list[int]", } diff --git a/tests/func/test_signal_schema.py b/tests/func/test_signal_schema.py new file mode 100644 index 000000000..d05918976 --- /dev/null +++ b/tests/func/test_signal_schema.py @@ -0,0 +1,76 @@ +import copy +import uuid + +import datachain as dc +from datachain import DataModel, func +from datachain.lib.model_store import ModelStore + + +def test_partial_collision_on_dataset_reload(test_session): + """ + Simulate two runs: + 1) Create and save a dataset whose schema includes a partial of Info + (partition by info.a). + 2) Reset the ModelStore, then create a different partial with the same + generated name (partition by info.b), and finally read the saved + dataset back. + + If partial names collide without structural checks, the dataset + deserialization will reuse the incompatible partial, causing a schema + mismatch. + """ + + class Info(DataModel): + a: int + b: str + + def make_chain(): + return dc.read_values(a=[1, 2], b=["x", "y"], session=test_session).map( + lambda a, b: Info(a=a, b=b), params=["a", "b"], output={"info": Info} + ) + + # Preserve and restore ModelStore across the test to avoid leaking state. + original_store = copy.deepcopy(ModelStore.store) + try: + # First run: build and save dataset using a partial on info.a. + ModelStore.store = {} + ds_name = f"partial-collision-{uuid.uuid4()}" + make_chain().group_by(cnt=func.count(), partition_by="info.a").save(ds_name) + + partials_run1 = [] + for name, versions in ModelStore.store.items(): + if name.startswith("InfoPartial_"): + partials_run1.extend(versions.values()) + assert len(partials_run1) == 2 + assert all(set(p.model_fields.keys()) == {"a"} for p in partials_run1) + + # Second run: reset registry and create a different partial with the + # same base name but a different structure (partition by info.b). + ModelStore.store = {} + make_chain().group_by(cnt=func.count(), partition_by="info.b") + + # Now read back the saved dataset; it should bring in the original + # partial definition and register it in ModelStore. + dc.read_dataset(ds_name, session=test_session) + + partials = { + name: model + for name, versions in ModelStore.store.items() + if name.startswith("InfoPartial_") + for model in versions.values() + } + + # There should be two distinct partial bases (info.a and info.b) registered + # after reading the dataset. 
+ fields_by_base: dict[str, set[str]] = {} + for name, model in partials.items(): + base = name.removesuffix("_v1") + fields_by_base.setdefault(base, set()).update(model.model_fields.keys()) + + assert len(fields_by_base) == 2 + actual_fields = sorted( + tuple(sorted(fields)) for fields in fields_by_base.values() + ) + assert actual_fields == [("a",), ("b",)] + finally: + ModelStore.store = original_store diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 40eb022ec..cdf274a17 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -2,12 +2,9 @@ from datetime import datetime from typing import ( Any, - Dict, Final, - List, - Literal, + ForwardRef, Optional, - Tuple, Union, get_args, get_origin, @@ -85,6 +82,8 @@ class MyTypeComplexOld(DataModel): def test_deserialize_basic(): + # Make this test robust to other tests clearing the global ModelStore. + ModelStore.register(File) stored = {"name": "str", "count": "int", "file": "File@v1"} signals = SignalSchema.deserialize(stored) @@ -139,6 +138,13 @@ def test_serialize_basic(): assert "File@v1" in signals["_custom_types"] +def test_serialize_warns_for_unresolved_forwardref(): + schema = SignalSchema({"x": ForwardRef("X")}) # type: ignore[arg-type] + + with pytest.warns(SignalSchemaWarning, match=r"Unable to determine name of type"): + assert schema.serialize()["x"] == "Any" + + def test_feature_schema_serialize_optional(): schema = { "name": str | None, @@ -742,7 +748,7 @@ def test_deserialize_restores_known_base_type(): signals = SignalSchema(schema).serialize() ModelStore.remove(MyType3) - # Seince MyType3 is removed, deserialization restores it + # Since MyType3 is removed, deserialization restores it # from the meta information stored in the schema, including the base type # that is still known - MyType1 deserialized_schema = SignalSchema.deserialize(signals) @@ -1039,82 +1045,6 @@ def test_build_tree(): ] -def test_print_types(): - mapping = { - int: "int", - float: "float", - None: "NoneType", - Ellipsis: "...", - MyType2: "MyType2@v1", - Any: "Any", - Literal: "Literal", - Final: "Final", - Optional[MyType2]: "Optional[MyType2@v1]", - Union[MyType2 | None]: "Optional[MyType2@v1]", - Optional[MyType2]: "Optional[MyType2@v1]", - MyType2 | None: "Optional[MyType2@v1]", - Union[str, int]: "Union[str, int]", - str | int: "Union[str, int]", - Union[str, int, bool]: "Union[str, int, bool]", - str | int | bool: "Union[str, int, bool]", - List: "list", - list: "list", - Tuple: "tuple", - tuple: "tuple", - List[bool]: "list[bool]", - list[bool]: "list[bool]", - List[bool | None]: "list[Optional[bool]]", - list[bool | None]: "list[Optional[bool]]", - List[int]: "list[int]", - list[int]: "list[int]", - Dict: "dict", - dict: "dict", - Dict[str, bool]: "dict[str, bool]", - dict[str, bool]: "dict[str, bool]", - Dict[str, int]: "dict[str, int]", - dict[str, int]: "dict[str, int]", - dict[str, MyType1 | None]: "dict[str, Optional[MyType1@v1]]", - dict[str, MyType1 | None]: "dict[str, Optional[MyType1@v1]]", - dict[str, MyType1 | None]: "dict[str, Optional[MyType1@v1]]", - dict[str, MyType1 | None]: "dict[str, Optional[MyType1@v1]]", - dict[str, MyType1 | None]: "dict[str, Optional[MyType1@v1]]", - Union[str, list[str]]: "Union[str, list[str]]", - Union[str, list[str]]: "Union[str, list[str]]", - str | list[str]: "Union[str, list[str]]", - str | list[str]: "Union[str, list[str]]", - Optional[Literal["x"]]: "Optional[Literal]", - Literal["x"] | None: "Optional[Literal]", - 
Optional[list[bytes]]: "Optional[list[bytes]]", - Optional[list[bytes]]: "Optional[list[bytes]]", - Union[list[bytes], None]: "Optional[list[bytes]]", - Union[list[bytes], None]: "Optional[list[bytes]]", - list[bytes] | None: "Optional[list[bytes]]", - list[bytes] | None: "Optional[list[bytes]]", - list[Any]: "list[Any]", - list[Any]: "list[Any]", - tuple[int, float]: "tuple[int, float]", - Tuple[int, float]: "tuple[int, float]", - tuple[int, ...]: "tuple[int, ...]", - Tuple[int, ...]: "tuple[int, ...]", - Optional[tuple[int, float]]: "Optional[tuple[int, float]]", - Optional[Tuple[int, float]]: "Optional[tuple[int, float]]", - Optional[tuple[int, ...]]: "Optional[tuple[int, ...]]", - Optional[Tuple[int, ...]]: "Optional[tuple[int, ...]]", - } - - for t, v in mapping.items(): - assert SignalSchema._type_to_str(t) == v - - # Test that unknown types are ignored, but raise a warning. - mapping_warnings = { - 5: "Any", - "UnknownType": "Any", - } - for t, v in mapping_warnings.items(): - with pytest.warns(SignalSchemaWarning): - assert SignalSchema._type_to_str(t) == v - - def test_resolve_types(): mapping = { "int": int, @@ -1171,32 +1101,6 @@ def test_resolve_types(): assert SignalSchema._resolve_type(s, {}) == t -def test_type_to_str_typing_module_vs_builtin_generics(): - """Test that typing.List/Dict and list/dict behave identically with get_origin(). - - This ensures Python 3.10+ compatibility where both old-style (typing.List) - and new-style (list[]) generic annotations are normalized by get_origin() - to the same built-in types. - """ - from typing import get_origin - - # Verify get_origin() normalizes both forms to built-in types - assert get_origin(List[int]) is list - assert get_origin(list[int]) is list - assert get_origin(Dict[str, int]) is dict - assert get_origin(dict[str, int]) is dict - - # Verify _type_to_str produces identical output for both forms - assert SignalSchema._type_to_str(List[int]) == SignalSchema._type_to_str(list[int]) - assert SignalSchema._type_to_str(Dict[str, int]) == SignalSchema._type_to_str( - dict[str, int] - ) - assert SignalSchema._type_to_str(List[str]) == "list[str]" - assert SignalSchema._type_to_str(list[str]) == "list[str]" - assert SignalSchema._type_to_str(Dict[str, bool]) == "dict[str, bool]" - assert SignalSchema._type_to_str(dict[str, bool]) == "dict[str, bool]" - - def test_resolve_types_errors(): bogus_types_messages = { "": r"cannot be empty", @@ -1278,7 +1182,7 @@ def test_row_to_objs_all_none_returns_none(): assert res == [None] -def test_row_to_objs_partial_none_raises(): +def test_row_to_objs_some_none_values_raises(): schema = SignalSchema({"fr": MyType2}) row = ("name", None, None) @@ -1297,7 +1201,7 @@ def test_row_to_objs_all_none_nested_collections(): assert res == [5, None, "tag"] -def test_row_to_objs_nested_collections_partial_data_raises(): +def test_row_to_objs_nested_collections_some_values_missing_raises(): schema = SignalSchema({"id": int, "complex": MyTypeComplex, "label": str}) row = (5, "component", ["bad"], {"key": "value"}, "tag") @@ -1499,346 +1403,6 @@ def test_column_types(column_type, signal_type): assert signals["val"] is signal_type -def test_to_partial(): - schema = SignalSchema({"name": str, "age": float, "f": File}) - partial = schema.to_partial("name", "f.path") - assert set(partial.values) == {"name", "f"} - assert partial.values["name"] is str - - file_partial = partial.values["f"] - assert issubclass(file_partial, DataModel) - assert file_partial.__name__.startswith("FilePartial") - assert 
set(file_partial.model_fields) == {"path"} - assert file_partial.model_fields["path"].annotation is str - - serialized = partial.serialize() - assert serialized["name"] == "str" - assert serialized["f"] == ModelStore.get_name(file_partial) - assert ModelStore.get_name(file_partial) in serialized["_custom_types"] - - -def test_to_partial_duplicate(): - schema = SignalSchema({"name": str, "age": float, "f1": File, "f2": File}) - partial = schema.to_partial("age", "f1.path", "f2.source") - assert set(partial.values) == {"age", "f1", "f2"} - assert partial.values["age"] is float - - f1_partial = partial.values["f1"] - f2_partial = partial.values["f2"] - - assert issubclass(f1_partial, DataModel) - assert issubclass(f2_partial, DataModel) - assert f1_partial is not f2_partial - - assert f1_partial.__name__.startswith("FilePartial") - assert f2_partial.__name__.startswith("FilePartial") - - assert set(f1_partial.model_fields) == {"path"} - assert f1_partial.model_fields["path"].annotation is str - - assert set(f2_partial.model_fields) == {"source"} - assert f2_partial.model_fields["source"].annotation is str - - serialized = partial.serialize() - assert serialized["age"] == "float" - assert serialized["f1"] == ModelStore.get_name(f1_partial) - assert serialized["f2"] == ModelStore.get_name(f2_partial) - assert ModelStore.get_name(f1_partial) in serialized["_custom_types"] - assert ModelStore.get_name(f2_partial) in serialized["_custom_types"] - - -def test_to_partial_nested(): - class Custom(DataModel): - foo: str - file: File - - schema = SignalSchema({"name": str, "age": float, "f": File, "custom": Custom}) - partial = schema.to_partial("name", "f.path", "custom.file.source") - assert set(partial.values) == {"name", "f", "custom"} - assert partial.values["name"] is str - - f_partial = partial.values["f"] - assert issubclass(f_partial, DataModel) - assert set(f_partial.model_fields) == {"path"} - assert f_partial.model_fields["path"].annotation is str - assert f_partial.__name__.startswith("FilePartial") - - custom_partial = partial.values["custom"] - assert issubclass(custom_partial, DataModel) - assert set(custom_partial.model_fields) == {"file"} - assert custom_partial.__name__.startswith("CustomPartial") - - nested_file_partial = custom_partial.model_fields["file"].annotation - assert issubclass(nested_file_partial, DataModel) - assert nested_file_partial is not f_partial - assert set(nested_file_partial.model_fields) == {"source"} - assert nested_file_partial.model_fields["source"].annotation is str - assert nested_file_partial.__name__.startswith("FilePartial") - - serialized = partial.serialize() - assert serialized["name"] == "str" - assert serialized["f"] == ModelStore.get_name(f_partial) - assert serialized["custom"] == ModelStore.get_name(custom_partial) - assert ModelStore.get_name(nested_file_partial) in serialized["_custom_types"] - - -def test_get_file_signal(): - assert SignalSchema({"name": str, "f": File}).get_file_signal() == "f" - assert SignalSchema({"name": str}).get_file_signal() is None - - -def test_to_partial_complex_signal_entire_file(): - """Test to_partial with entire complex signal requested.""" - schema = SignalSchema({"file": File, "name": str}) - partial = schema.to_partial("file") - - # Should return the entire File complex signal - assert partial.values == {"file": File} - - -def test_to_partial_complex_nested_signal(): - class Custom(DataModel): - src: File - type: str - - schema = SignalSchema({"my_col": Custom, "name": str}) - partial = 
schema.to_partial("my_col.src") - - assert set(partial.values) == {"my_col"} - - custom_partial = partial.values["my_col"] - assert issubclass(custom_partial, DataModel) - assert set(custom_partial.model_fields) == {"src"} - assert custom_partial.model_fields["src"].annotation is File - assert custom_partial.__name__.startswith("CustomPartial") - - serialized = partial.serialize() - assert serialized["my_col"] == ModelStore.get_name(custom_partial) - assert "_custom_types" in serialized - - -def test_to_partial_complex_deeply_nested_signal(): - """Test to_partial with deeply nested complex signals (3+ levels).""" - from datachain.lib.file import ImageFile - - class Level1(DataModel): - image: ImageFile - name: str - - class Level2(DataModel): - level1: Level1 - category: str - - class Level3(DataModel): - level2: Level2 - id: str - - schema = SignalSchema({"deep": Level3, "simple": str}) - - # Test deeply nested complex signal - partial = schema.to_partial("deep.level2.level1.image") - - deep_partial = partial.values["deep"] - level2_partial = deep_partial.model_fields["level2"].annotation - level1_partial = level2_partial.model_fields["level1"].annotation - - assert issubclass(level1_partial, DataModel) - assert set(level1_partial.model_fields) == {"image"} - assert level1_partial.model_fields["image"].annotation is ImageFile - assert deep_partial.__name__.startswith("Level3Partial") - assert level2_partial.__name__.startswith("Level2Partial") - assert level1_partial.__name__.startswith("Level1Partial") - - serialized = partial.serialize() - assert serialized["deep"] == ModelStore.get_name(deep_partial) - assert ModelStore.get_name(level1_partial) in serialized["_custom_types"] - - -def test_to_partial_complex_nested_multiple_complex_signals(): - """Test to_partial with multiple nested complex signals.""" - from datachain.lib.file import TextFile - - class Container(DataModel): - file1: File - file2: TextFile - name: str - - schema = SignalSchema({"container": Container, "simple": str}) - - # Request multiple nested complex signals - partial = schema.to_partial("container.file1", "container.file2") - - assert set(partial.values) == {"container"} - - container_partial = partial.values["container"] - assert issubclass(container_partial, DataModel) - assert container_partial.model_fields["file1"].annotation is File - assert container_partial.model_fields["file2"].annotation is TextFile - assert container_partial.__name__.startswith("ContainerPartial") - - serialized = partial.serialize() - assert serialized["container"] == ModelStore.get_name(container_partial) - - -def test_to_partial_complex_nested_mixed_complex_and_simple(): - """Test to_partial with mix of nested complex signals and simple fields.""" - - class Container(DataModel): - file: File - name: str - count: int - - schema = SignalSchema({"container": Container, "simple": str}) - - # Request mix of nested complex signal and simple field - partial = schema.to_partial("container.file", "container.name", "simple") - - assert set(partial.values) == {"container", "simple"} - assert partial.values["simple"] is str - - container_partial = partial.values["container"] - assert issubclass(container_partial, DataModel) - assert container_partial.model_fields["file"].annotation is File - assert container_partial.model_fields["name"].annotation is str - assert container_partial.__name__.startswith("ContainerPartial") - - serialized = partial.serialize() - assert serialized["container"] == ModelStore.get_name(container_partial) - assert 
serialized["simple"] == "str" - - -def test_to_partial_complex_nested_same_type_different_paths(): - """Test to_partial with same complex type accessed via different nested paths.""" - - class Container1(DataModel): - file: File - name: str - - class Container2(DataModel): - file: File - category: str - - schema = SignalSchema({"cont1": Container1, "cont2": Container2}) - - # Request same complex type from different nested paths - partial = schema.to_partial("cont1.file", "cont2.file") - - assert set(partial.values) == {"cont1", "cont2"} - - cont1_partial = partial.values["cont1"] - cont2_partial = partial.values["cont2"] - assert issubclass(cont1_partial, DataModel) - assert issubclass(cont2_partial, DataModel) - assert cont1_partial is not cont2_partial - - assert cont1_partial.model_fields["file"].annotation is File - assert cont2_partial.model_fields["file"].annotation is File - assert cont1_partial.__name__.startswith("Container1Partial") - assert cont2_partial.__name__.startswith("Container2Partial") - - serialized = partial.serialize() - assert serialized["cont1"] == ModelStore.get_name(cont1_partial) - assert serialized["cont2"] == ModelStore.get_name(cont2_partial) - - -def test_to_partial_complex_signal_file_single_field(): - """Test to_partial with File complex signal - single field.""" - schema = SignalSchema({"name": str, "file": File}) - partial = schema.to_partial("file.path") - - assert set(partial.values) == {"file"} - - file_partial = partial.values["file"] - assert issubclass(file_partial, DataModel) - assert set(file_partial.model_fields) == {"path"} - assert file_partial.model_fields["path"].annotation is str - assert file_partial.__name__.startswith("FilePartial") - - serialized = partial.serialize() - assert serialized["file"] == ModelStore.get_name(file_partial) - - -def test_to_partial_complex_signal_mixed_entire_and_fields(): - """Test to_partial with mix of entire complex signal and specific fields.""" - schema = SignalSchema({"file1": File, "file2": File, "name": str}) - partial = schema.to_partial("file1", "file2.path", "name") - - assert set(partial.values) == {"file1", "file2", "name"} - - assert partial.values["file1"] is File - assert partial.values["name"] is str - - file2_partial = partial.values["file2"] - assert issubclass(file2_partial, DataModel) - assert set(file2_partial.model_fields) == {"path"} - assert file2_partial.model_fields["path"].annotation is str - assert file2_partial.__name__.startswith("FilePartial") - - serialized = partial.serialize() - assert serialized["file1"] == "File@v1" - assert serialized["file2"] == ModelStore.get_name(file2_partial) - assert serialized["name"] == "str" - assert ModelStore.get_name(file2_partial) in serialized["_custom_types"] - - -def test_to_partial_complex_signal_multiple_entire_files(): - """Test to_partial with multiple entire complex signals.""" - schema = SignalSchema({"file1": File, "file2": File, "name": str}) - partial = schema.to_partial("file1", "file2") - - assert set(partial.values) == {"file1", "file2"} - assert partial.values["file1"] is File - assert partial.values["file2"] is File - - -def test_to_partial_complex_signal_nested_entire(): - """Test to_partial with nested complex signal - entire parent.""" - - class Container(DataModel): - name: str - file: File - - schema = SignalSchema({"container": Container, "simple": str}) - partial = schema.to_partial("container") - - assert set(partial.values) == {"container"} - - container_type = partial.values["container"] - assert 
issubclass(container_type, DataModel) - assert set(container_type.model_fields) == {"name", "file"} - assert container_type.model_fields["name"].annotation is str - assert container_type.model_fields["file"].annotation is File - - -def test_to_partial_complex_signal_empty_request(): - """Test to_partial with no columns requested.""" - schema = SignalSchema({"file": File, "name": str}) - partial = schema.to_partial() - - # Should return empty schema - assert partial.values == {} - - -def test_to_partial_complex_signal_error_invalid_signal(): - """Test to_partial with invalid signal name.""" - schema = SignalSchema({"file": File}) - - with pytest.raises( - SignalSchemaError, match="Column nonexistent not found in the schema" - ): - schema.to_partial("nonexistent") - - -def test_to_partial_complex_signal_error_invalid_field(): - """Test to_partial with invalid field in complex signal.""" - schema = SignalSchema({"file": File}) - - with pytest.raises( - SignalSchemaError, match="Field nonexistent not found in custom type" - ): - schema.to_partial("file.nonexistent") - - @pytest.mark.parametrize( "schema,_hash", [ diff --git a/tests/unit/lib/test_utils.py b/tests/unit/lib/test_utils.py index d1bb61d7a..425ece1ed 100644 --- a/tests/unit/lib/test_utils.py +++ b/tests/unit/lib/test_utils.py @@ -1,10 +1,29 @@ +import copy from collections.abc import Iterable +from typing import ( # noqa: UP035 + Annotated, + Any, + Dict, + Final, + List, + Literal, + Optional, + Tuple, + Union, +) import pytest from pydantic import BaseModel from datachain.lib.convert.python_to_sql import python_to_sql -from datachain.lib.utils import callable_name, normalize_col_names, rebase_path +from datachain.lib.data_model import DataModel +from datachain.lib.model_store import ModelStore +from datachain.lib.utils import ( + callable_name, + normalize_col_names, + rebase_path, + type_to_str, +) from datachain.sql.types import Array, String @@ -201,6 +220,115 @@ def method(self): assert callable_name(b.method) == "method" +@pytest.mark.parametrize( + "type_, expected", + [ + (int, "int"), + (float, "float"), + (None, "NoneType"), + (Ellipsis, "..."), + (Any, "Any"), + (Final[int], "Final"), + (Optional[int], "Optional[int]"), # noqa: UP045 + (int | str, {"Union[int, str]", "Union[str, int]"}), + ( + str | int | bool, + { + "Union[int, str, bool]", + "Union[str, int, bool]", + "Union[str, bool, int]", + "Union[int, bool, str]", + "Union[bool, str, int]", + "Union[bool, int, str]", + }, + ), + (Annotated[int, "meta"], "int"), + (list[Any], "list[Any]"), + (list[bool], "list[bool]"), + (List[bool], "list[bool]"), # noqa: UP006 + (list[bool | None], "list[Optional[bool]]"), + (List[bool | None], "list[Optional[bool]]"), # noqa: UP006 + (List[int], "list[int]"), # noqa: UP006 + (List[str], "list[str]"), # noqa: UP006 + (Optional[list[bytes]], "Optional[list[bytes]]"), # noqa: UP045 + (Literal["x"] | None, "Optional[Literal]"), + (tuple[int, float], "tuple[int, float]"), + (tuple[int, ...], "tuple[int, ...]"), + (Optional[tuple[int, float]], "Optional[tuple[int, float]]"), # noqa: UP045 + (dict[str], "dict[str, Any]"), # type: ignore[misc] + (dict[str, bool], "dict[str, bool]"), + (Dict[str, bool], "dict[str, bool]"), # noqa: UP006 + (dict[str, int], "dict[str, int]"), + (Dict[str, int], "dict[str, int]"), # noqa: UP006 + (Union[list[bytes], None], "Optional[list[bytes]]"), # noqa: UP007 + (Union[List[bytes], None], "Optional[list[bytes]]"), # noqa: UP006, UP007 + ], +) +def test_type_to_str_matrix(type_, expected): + result = 
type_to_str(type_) + if isinstance(expected, set): + assert result in expected + else: + assert result == expected + + +def test_type_to_str_typing_module_vs_builtin_generics(): + """Ensure typing.List/Dict and built-in generics stringify identically. + + Confirms Python 3.10+ behavior where get_origin() normalizes both forms + to built-ins, and type_to_str produces the same string. + """ + from typing import get_origin + + assert get_origin(List[int]) is list # noqa: UP006 + assert get_origin(list[int]) is list + assert get_origin(Dict[str, int]) is dict # noqa: UP006 + assert get_origin(dict[str, int]) is dict + + assert type_to_str(List[int]) == type_to_str(list[int]) # noqa: UP006 + assert type_to_str(Dict[str, int]) == type_to_str(dict[str, int]) # noqa: UP006 + assert type_to_str(List[str]) == "list[str]" # noqa: UP006 + assert type_to_str(list[str]) == "list[str]" + assert type_to_str(Dict[str, bool]) == "dict[str, bool]" # noqa: UP006 + assert type_to_str(dict[str, bool]) == "dict[str, bool]" + + +def test_type_to_str_warn_with_called_for_unknown(): + # Unknown types should fall back to Any but emit a warning via the callback. + calls: list[str] = [] + + def collect(msg: str) -> None: + calls.append(msg) + + result = type_to_str(object(), warn_with=collect) + assert result == "Any" + assert calls and "Unable to determine name" in calls[0] + + +def test_type_to_str_warns_without_callback(): + with pytest.warns(RuntimeWarning, match="Unable to determine name"): + assert type_to_str(object()) == "Any" + + +def test_type_to_str_empty_generics(): + assert type_to_str(List) == "list" # noqa: UP006 + assert type_to_str(Dict) == "dict" # noqa: UP006 + assert type_to_str(Tuple) == "tuple" # noqa: UP006 + + +def test_type_to_str_pydantic_model_uses_model_store(): + snapshot = copy.deepcopy(ModelStore.store) + ModelStore.store = {} + try: + + class Sample(DataModel): + a: int + + assert type_to_str(Sample) == ModelStore.get_name(Sample) + finally: + ModelStore.store = snapshot + + def test_callable_name_callable_instance(): class Foo: def __call__(self, x): diff --git a/tests/unit/test_data_model.py b/tests/unit/test_data_model.py new file mode 100644 index 000000000..7950289bb --- /dev/null +++ b/tests/unit/test_data_model.py @@ -0,0 +1,83 @@ +import copy + +import pytest + +from datachain.lib.data_model import DataModel, compute_model_fingerprint +from datachain.lib.model_store import ModelStore + + +@pytest.fixture(autouse=True) +def restore_model_store(): + snapshot = copy.deepcopy(ModelStore.store) + ModelStore.store = {} + try: + yield + finally: + ModelStore.store = snapshot + + +def test_compute_model_fingerprint_missing_field(): + class Sample(DataModel): + a: int + + with pytest.raises(ValueError, match="Field missing not found in Sample"): + compute_model_fingerprint(Sample, {"missing": None}) + + +def test_compute_model_fingerprint_non_model_child(): + class Sample(DataModel): + a: int + + with pytest.raises(ValueError, match="Field a in Sample is not a model"): + compute_model_fingerprint(Sample, {"a": {"child": None}}) + + +def test_compute_model_fingerprint_stable_for_same_selection(): + class Sample(DataModel): + a: int + b: int + + sel = {"a": None} + fp1 = compute_model_fingerprint(Sample, sel) + fp2 = compute_model_fingerprint(Sample, sel) + assert fp1 == fp2 + + +def test_compute_model_fingerprint_changes_with_selection(): + class Sample(DataModel): + a: int + b: int + + fp_a = compute_model_fingerprint(Sample, {"a": None}) + fp_b = compute_model_fingerprint(Sample, {"b": 
None}) + assert fp_a != fp_b + + +def test_compute_model_fingerprint_nested_model(): + class Child(DataModel): + x: int + y: int + + class Parent(DataModel): + child: Child + z: int + + fp_child_x = compute_model_fingerprint(Parent, {"child": {"x": None}}) + fp_child_y = compute_model_fingerprint(Parent, {"child": {"y": None}}) + fp_child_all = compute_model_fingerprint(Parent, {"child": {"x": None, "y": None}}) + + assert fp_child_x != fp_child_y + assert fp_child_all != fp_child_x + assert fp_child_all != fp_child_y + + +def test_compute_model_fingerprint_required_vs_optional_differs(): + class Required(DataModel): + value: int + + class OptionalField(DataModel): + value: int | None = None + + fp_required = compute_model_fingerprint(Required, {"value": None}) + fp_optional = compute_model_fingerprint(OptionalField, {"value": None}) + assert fp_required != fp_optional diff --git a/tests/unit/test_signal_schema_partials.py b/tests/unit/test_signal_schema_partials.py new file mode 100644 index 000000000..8106640aa --- /dev/null +++ b/tests/unit/test_signal_schema_partials.py @@ -0,0 +1,534 @@ +import copy + +import pytest +from pydantic import Field + +from datachain import DataModel +from datachain.lib.data_model import compute_model_fingerprint +from datachain.lib.file import File, TextFile +from datachain.lib.model_store import ModelStore +from datachain.lib.signal_schema import ( + SignalResolvingTypeError, + SignalSchema, + SignalSchemaError, + create_feature_model, +) + + +class Info(DataModel): + a: int + b: int + + +def _reset_model_store(): + ModelStore.store = {} + + +@pytest.fixture(autouse=True) +def _autoreset_model_store(): + snapshot = copy.deepcopy(ModelStore.store) + try: + ModelStore.store = {} + yield + finally: + ModelStore.store = snapshot + + +def test_partial_same_selection_reuses_name(): + schema = SignalSchema({"info": Info}) + + schema.to_partial("info.a") + selection = {"a": None} + fingerprint = compute_model_fingerprint(Info, selection) + base_partial_name = f"InfoPartial_{fingerprint[:10]}" + + names_after_first = set(ModelStore.store) + assert names_after_first == {base_partial_name, f"{base_partial_name}_v1"} + + schema.to_partial("info.a") + names_after_second = set(ModelStore.store) + + assert names_after_first == names_after_second + + +def test_partial_different_selection_differs(): + schema = SignalSchema({"info": Info}) + + schema.to_partial("info.a") + selection_a = {"a": None} + fingerprint_a = compute_model_fingerprint(Info, selection_a) + base_a = f"InfoPartial_{fingerprint_a[:10]}" + + schema.to_partial("info.b") + selection_b = {"b": None} + fingerprint_b = compute_model_fingerprint(Info, selection_b) + base_b = f"InfoPartial_{fingerprint_b[:10]}" + + names = {name for name in ModelStore.store if name.startswith("InfoPartial_")} + + assert names == {base_a, f"{base_a}_v1", base_b, f"{base_b}_v1"} + + +def test_partial_name_collision_disambiguates(): + schema = SignalSchema({"info": Info}) + + # Pre-register a conflicting model using the deterministic base name but + # wrong fingerprint + selection = {"a": None} + fingerprint = compute_model_fingerprint(Info, selection) + base_name, _ = ModelStore.parse_name_version(ModelStore.get_name(Info)) + colliding_name = f"{base_name}Partial_{fingerprint[:10]}@v1" + + rogue = create_feature_model( + colliding_name, + {"a": (int, None)}, + base=DataModel, + ) + rogue._partial_fingerprint = "wrong" # type: ignore[attr-defined] + ModelStore.register(rogue) + + with pytest.raises(SignalSchemaError, 
match="partial model name collision"): + schema.to_partial("info.a") + + +def test_partial_fingerprint_roundtrip_serialization(): + schema = SignalSchema({"info": Info}) + + partial_schema = schema.to_partial("info.a") + partial_model = ModelStore.to_pydantic(partial_schema.values["info"]) + orig_fp = getattr(partial_model, "_partial_fingerprint", None) + + serialized = partial_schema.serialize() + _reset_model_store() + + roundtrip = SignalSchema.deserialize(serialized) + rt_model = ModelStore.to_pydantic(roundtrip.values["info"]) + + assert getattr(rt_model, "_partial_fingerprint", None) == orig_fp + + +def test_to_partial(): + schema = SignalSchema({"name": str, "age": float, "f": File}) + partial = schema.to_partial("name", "f.path") + assert set(partial.values) == {"name", "f"} + assert partial.values["name"] is str + + file_partial = partial.values["f"] + assert issubclass(file_partial, DataModel) + assert file_partial.__name__.startswith("FilePartial") + assert set(file_partial.model_fields) == {"path"} + assert file_partial.model_fields["path"].annotation is str + + serialized = partial.serialize() + assert serialized["name"] == "str" + assert serialized["f"] == ModelStore.get_name(file_partial) + assert ModelStore.get_name(file_partial) in serialized["_custom_types"] + + +def test_to_partial_duplicate(): + schema = SignalSchema({"name": str, "age": float, "f1": File, "f2": File}) + partial = schema.to_partial("age", "f1.path", "f2.source") + assert set(partial.values) == {"age", "f1", "f2"} + assert partial.values["age"] is float + + f1_partial = partial.values["f1"] + f2_partial = partial.values["f2"] + + assert issubclass(f1_partial, DataModel) + assert issubclass(f2_partial, DataModel) + assert f1_partial is not f2_partial + + assert f1_partial.__name__.startswith("FilePartial") + assert f2_partial.__name__.startswith("FilePartial") + + assert set(f1_partial.model_fields) == {"path"} + assert f1_partial.model_fields["path"].annotation is str + + assert set(f2_partial.model_fields) == {"source"} + assert f2_partial.model_fields["source"].annotation is str + + serialized = partial.serialize() + assert serialized["age"] == "float" + assert serialized["f1"] == ModelStore.get_name(f1_partial) + assert serialized["f2"] == ModelStore.get_name(f2_partial) + assert ModelStore.get_name(f1_partial) in serialized["_custom_types"] + assert ModelStore.get_name(f2_partial) in serialized["_custom_types"] + + +def test_to_partial_multiple_calls_unique_partial_names(): + schema = SignalSchema({"file": File, "name": str}) + + partial1 = schema.to_partial("file.path") + partial2 = schema.to_partial("file.source") + + file_partial_1 = partial1.values["file"] + file_partial_2 = partial2.values["file"] + + # Each call should produce a distinct partial model to avoid name collisions + assert file_partial_1 is not file_partial_2 + assert file_partial_1.__name__ != file_partial_2.__name__ + + assert set(file_partial_1.model_fields) == {"path"} + assert file_partial_1.model_fields["path"].annotation is str + + assert set(file_partial_2.model_fields) == {"source"} + assert file_partial_2.model_fields["source"].annotation is str + + serialized_1 = partial1.serialize() + serialized_2 = partial2.serialize() + + assert serialized_1["file"] != serialized_2["file"] + assert serialized_1["file"] in serialized_1["_custom_types"] + assert serialized_2["file"] in serialized_2["_custom_types"] + + +def test_to_partial_nested(): + class Custom(DataModel): + foo: str + file: File + + schema = SignalSchema({"name": 
str, "age": float, "f": File, "custom": Custom}) + partial = schema.to_partial("name", "f.path", "custom.file.source") + assert set(partial.values) == {"name", "f", "custom"} + assert partial.values["name"] is str + + f_partial = partial.values["f"] + assert issubclass(f_partial, DataModel) + assert set(f_partial.model_fields) == {"path"} + assert f_partial.model_fields["path"].annotation is str + assert f_partial.__name__.startswith("FilePartial") + + custom_partial = partial.values["custom"] + assert issubclass(custom_partial, DataModel) + assert set(custom_partial.model_fields) == {"file"} + assert custom_partial.__name__.startswith("CustomPartial") + + nested_file_partial = custom_partial.model_fields["file"].annotation + assert issubclass(nested_file_partial, DataModel) + assert nested_file_partial is not f_partial + assert set(nested_file_partial.model_fields) == {"source"} + assert nested_file_partial.model_fields["source"].annotation is str + assert nested_file_partial.__name__.startswith("FilePartial") + + serialized = partial.serialize() + assert serialized["name"] == "str" + assert serialized["f"] == ModelStore.get_name(f_partial) + assert serialized["custom"] == ModelStore.get_name(custom_partial) + assert ModelStore.get_name(nested_file_partial) in serialized["_custom_types"] + + +def test_get_file_signal(): + assert SignalSchema({"name": str, "f": File}).get_file_signal() == "f" + assert SignalSchema({"name": str}).get_file_signal() is None + + +def test_to_partial_complex_signal_entire_file(): + """Test to_partial with entire complex signal requested.""" + schema = SignalSchema({"file": File, "name": str}) + partial = schema.to_partial("file") + + # Should return the entire File complex signal + assert partial.values == {"file": File} + + +def test_to_partial_complex_nested_signal(): + class Custom(DataModel): + src: File + type: str + + schema = SignalSchema({"my_col": Custom, "name": str}) + partial = schema.to_partial("my_col.src") + + assert set(partial.values) == {"my_col"} + + custom_partial = partial.values["my_col"] + assert issubclass(custom_partial, DataModel) + assert set(custom_partial.model_fields) == {"src"} + assert custom_partial.model_fields["src"].annotation is File + assert custom_partial.__name__.startswith("CustomPartial") + + serialized = partial.serialize() + assert serialized["my_col"] == ModelStore.get_name(custom_partial) + assert "_custom_types" in serialized + + +def test_to_partial_complex_deeply_nested_signal(): + """Test to_partial with deeply nested complex signals (3+ levels).""" + from datachain.lib.file import ImageFile + + class Level1(DataModel): + image: ImageFile + name: str + + class Level2(DataModel): + level1: Level1 + category: str + + class Level3(DataModel): + level2: Level2 + id: str + + schema = SignalSchema({"deep": Level3, "simple": str}) + + # Test deeply nested complex signal + partial = schema.to_partial("deep.level2.level1.image") + + deep_partial = partial.values["deep"] + level2_partial = deep_partial.model_fields["level2"].annotation + level1_partial = level2_partial.model_fields["level1"].annotation + + assert issubclass(level1_partial, DataModel) + assert set(level1_partial.model_fields) == {"image"} + assert level1_partial.model_fields["image"].annotation is ImageFile + assert deep_partial.__name__.startswith("Level3Partial") + assert level2_partial.__name__.startswith("Level2Partial") + assert level1_partial.__name__.startswith("Level1Partial") + + serialized = partial.serialize() + assert serialized["deep"] == 
ModelStore.get_name(deep_partial) + assert ModelStore.get_name(level1_partial) in serialized["_custom_types"] + + +def test_to_partial_complex_nested_multiple_complex_signals(): + """Test to_partial with multiple nested complex signals.""" + + class Container(DataModel): + file1: File + file2: TextFile + name: str + + schema = SignalSchema({"container": Container, "simple": str}) + + # Request multiple nested complex signals + partial = schema.to_partial("container.file1", "container.file2") + + assert set(partial.values) == {"container"} + + container_partial = partial.values["container"] + assert issubclass(container_partial, DataModel) + assert container_partial.model_fields["file1"].annotation is File + assert container_partial.model_fields["file2"].annotation is TextFile + assert container_partial.__name__.startswith("ContainerPartial") + + serialized = partial.serialize() + assert serialized["container"] == ModelStore.get_name(container_partial) + + +def test_to_partial_complex_nested_mixed_complex_and_simple(): + """Test to_partial with mix of nested complex signals and simple fields.""" + + class Container(DataModel): + file: File + name: str + count: int + + schema = SignalSchema({"container": Container, "simple": str}) + + # Request mix of nested complex signal and simple field + partial = schema.to_partial("container.file", "container.name", "simple") + + assert set(partial.values) == {"container", "simple"} + assert partial.values["simple"] is str + + container_partial = partial.values["container"] + assert issubclass(container_partial, DataModel) + assert container_partial.model_fields["file"].annotation is File + assert container_partial.model_fields["name"].annotation is str + assert container_partial.__name__.startswith("ContainerPartial") + + serialized = partial.serialize() + assert serialized["container"] == ModelStore.get_name(container_partial) + assert serialized["simple"] == "str" + + +def test_to_partial_complex_nested_same_type_different_paths(): + """Test to_partial with same complex type accessed via different nested paths.""" + + class Container1(DataModel): + file: File + name: str + + class Container2(DataModel): + file: File + category: str + + schema = SignalSchema({"cont1": Container1, "cont2": Container2}) + + # Request same complex type from different nested paths + partial = schema.to_partial("cont1.file", "cont2.file") + + assert set(partial.values) == {"cont1", "cont2"} + + cont1_partial = partial.values["cont1"] + cont2_partial = partial.values["cont2"] + assert issubclass(cont1_partial, DataModel) + assert issubclass(cont2_partial, DataModel) + assert cont1_partial is not cont2_partial + + assert cont1_partial.model_fields["file"].annotation is File + assert cont2_partial.model_fields["file"].annotation is File + assert cont1_partial.__name__.startswith("Container1Partial") + assert cont2_partial.__name__.startswith("Container2Partial") + + serialized = partial.serialize() + assert serialized["cont1"] == ModelStore.get_name(cont1_partial) + assert serialized["cont2"] == ModelStore.get_name(cont2_partial) + + +def test_to_partial_complex_signal_file_single_field(): + """Test to_partial with File complex signal - single field.""" + schema = SignalSchema({"name": str, "file": File}) + partial = schema.to_partial("file.path") + + assert set(partial.values) == {"file"} + + file_partial = partial.values["file"] + assert issubclass(file_partial, DataModel) + assert set(file_partial.model_fields) == {"path"} + assert 
file_partial.model_fields["path"].annotation is str + assert file_partial.__name__.startswith("FilePartial") + + serialized = partial.serialize() + assert serialized["file"] == ModelStore.get_name(file_partial) + + +def test_to_partial_complex_signal_mixed_entire_and_fields(): + """Test to_partial with mix of entire complex signal and specific fields.""" + schema = SignalSchema({"file1": File, "file2": File, "name": str}) + partial = schema.to_partial("file1", "file2.path", "name") + + assert set(partial.values) == {"file1", "file2", "name"} + + assert partial.values["file1"] is File + assert partial.values["name"] is str + + file2_partial = partial.values["file2"] + assert issubclass(file2_partial, DataModel) + assert set(file2_partial.model_fields) == {"path"} + assert file2_partial.model_fields["path"].annotation is str + assert file2_partial.__name__.startswith("FilePartial") + + serialized = partial.serialize() + assert serialized["file1"] == "File@v1" + assert serialized["file2"] == ModelStore.get_name(file2_partial) + assert serialized["name"] == "str" + assert ModelStore.get_name(file2_partial) in serialized["_custom_types"] + + +def test_to_partial_complex_signal_multiple_entire_files(): + """Test to_partial with multiple entire complex signals.""" + schema = SignalSchema({"file1": File, "file2": File, "name": str}) + partial = schema.to_partial("file1", "file2") + + assert set(partial.values) == {"file1", "file2"} + assert partial.values["file1"] is File + assert partial.values["file2"] is File + + +def test_to_partial_complex_signal_nested_entire(): + """Test to_partial with nested complex signal - entire parent.""" + + class Container(DataModel): + name: str + file: File + + schema = SignalSchema({"container": Container, "simple": str}) + partial = schema.to_partial("container") + + assert set(partial.values) == {"container"} + + container_type = partial.values["container"] + assert issubclass(container_type, DataModel) + assert set(container_type.model_fields) == {"name", "file"} + assert container_type.model_fields["name"].annotation is str + assert container_type.model_fields["file"].annotation is File + + +def test_to_partial_complex_signal_empty_request(): + """Test to_partial with no columns requested.""" + schema = SignalSchema({"file": File, "name": str}) + partial = schema.to_partial() + + # Should return empty schema + assert partial.values == {} + + +def test_to_partial_complex_signal_error_invalid_signal(): + """Test to_partial with invalid signal name.""" + schema = SignalSchema({"file": File}) + + with pytest.raises( + SignalSchemaError, match="Column nonexistent not found in the schema" + ): + schema.to_partial("nonexistent") + + +def test_to_partial_complex_signal_error_invalid_field(): + """Test to_partial with invalid field in complex signal.""" + schema = SignalSchema({"file": File}) + + with pytest.raises( + SignalSchemaError, + match=r"Field nonexistent not found in custom type File", + ): + schema.to_partial("file.nonexistent") + + +def test_to_partial_rejects_non_string_column(): + schema = SignalSchema({"name": str}) + + with pytest.raises( + SignalResolvingTypeError, match=r"to_partial\(\) supports only `str` type" + ): + schema.to_partial(123) + + +def test_to_partial_nested_on_scalar_column(): + schema = SignalSchema({"name": str}) + + with pytest.raises( + SignalSchemaError, match=r"Column name\.path not found in the schema" + ): + schema.to_partial("name.path") + + +def test_to_partial_prefers_whole_selection_over_fields(): + schema = 
SignalSchema({"file": File}) + + partial = schema.to_partial("file", "file.path") + + assert partial.values == {"file": File} + + +def test_to_partial_propagates_optional_default(): + class WithDefault(DataModel): + required: int + optional: str | None = Field(default="fallback") + + schema = SignalSchema({"data": WithDefault}) + + partial = schema.to_partial("data.optional") + partial_model = ModelStore.to_pydantic(partial.values["data"]) + assert partial_model is not None + + assert set(partial_model.model_fields) == {"optional"} + + optional_field = partial_model.model_fields["optional"] + + original_field = WithDefault.model_fields["optional"] + + assert optional_field.default == "fallback" + assert optional_field.annotation is original_field.annotation + + +def test_to_partial_does_not_create_model_when_all_fields_selected(): + class WithDefault(DataModel): + required: int + optional: str | None = Field(default="fallback") + + schema = SignalSchema({"data": WithDefault}) + + partial = schema.to_partial("data.required", "data.optional") + + # When the selection includes all fields, return the original model type. + assert partial.values == {"data": WithDefault}