Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/datachain/lib/data_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import inspect
import types
import uuid
Expand All @@ -8,8 +9,9 @@
from pydantic import AliasChoices, BaseModel, Field, create_model
from pydantic.fields import FieldInfo

from datachain import json
from datachain.lib.model_store import ModelStore
from datachain.lib.utils import normalize_col_names
from datachain.lib.utils import normalize_col_names, type_to_str

StandardType = (
type[int]
Expand Down Expand Up @@ -52,6 +54,57 @@ def hidden_fields(cls) -> list[str]:
return cls._hidden_fields


def compute_model_fingerprint(
model: type[BaseModel], selection: dict[str, "dict[str, object] | None"]
) -> str:
"""
Compute a deterministic fingerprint for a model given a selection subtree.

Selection uses the same structure as SignalSchema.to_partial: a mapping from
field name -> nested selection dict or None (leaf).
"""

def _fingerprint_tree(
model_type: type[BaseModel], sel: dict[str, "dict[str, object] | None"]
) -> dict[str, object]:
tree: dict[str, object] = {}
for field_name, sub_sel in sorted(sel.items()):
if field_name not in model_type.model_fields:
raise ValueError(
f"Field {field_name} not found in {model_type.__name__}"
)

finfo = model_type.model_fields[field_name]
field_type = finfo.annotation
required = finfo.is_required()
entry: dict[str, object] = {
"type": type_to_str(field_type, register_pydantic=False),
"required": bool(required),
"default": None if required else repr(finfo.default),
}

child_model = ModelStore.to_pydantic(field_type)
if sub_sel is not None:
if child_model is None:
raise ValueError(
f"Field {field_name} in {model_type.__name__} is not a model"
)
entry["children"] = _fingerprint_tree(
child_model,
sub_sel, # type: ignore[arg-type]
)
tree[field_name] = entry

return tree

payload = {
"model": ModelStore.get_name(model),
"selection": _fingerprint_tree(model, selection),
}
json_str = json.dumps(payload, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()


def is_chain_type(t: type) -> bool:
"""Return true if type is supported by `DataChain`."""
if ModelStore.is_pydantic(t):
Expand Down
39 changes: 31 additions & 8 deletions src/datachain/lib/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
class ModelStore:
store: ClassVar[dict[str, dict[int, type[BaseModel]]]] = {}

@staticmethod
def _base_name(model: type[BaseModel]) -> str:
# Some models are generated/restored with a versioned Python class name
# (e.g. `MyType_v1`) so that multiple versions can coexist in-process without
# `__name__` collisions.
#
# `_modelstore_base_name` preserves the original/logical name (e.g. `MyType`)
# after we alter the class name, so `ModelStore.get_name(model)` still
# returns `"MyType@v{model._version}"` and schema serialization stays stable.
return getattr(model, "_modelstore_base_name", model.__name__)

@classmethod
def get_version(cls, model: type[BaseModel]) -> int:
if not hasattr(model, "_version"):
Expand All @@ -15,21 +26,29 @@ def get_version(cls, model: type[BaseModel]) -> int:

@classmethod
def get_name(cls, model) -> str:
base_name = cls._base_name(model)
if (version := cls.get_version(model)) > 0:
return f"{model.__name__}@v{version}"
return model.__name__
return f"{base_name}@v{version}"
return base_name

@classmethod
def register(cls, fr: type):
"""Register a class as a data model for deserialization."""
if (model := ModelStore.to_pydantic(fr)) is None:
return

name = model.__name__
if name not in cls.store:
cls.store[name] = {}
base_name = cls._base_name(model)
unique_name = model.__name__
version = ModelStore.get_version(model)
cls.store[name][version] = model

# Register under both:
# - `base_name` (logical/original name, from `_modelstore_base_name` when set)
# so callers can resolve by original name + version, e.g.
# `ModelStore.get("Foo", 1)` after deserialization/restore.
# - `unique_name` (Python class `__name__`, e.g. `Foo_v1`) so callers can also
# resolve by the runtime class name.
for name in (base_name, unique_name):
cls.store.setdefault(name, {})[version] = model

for f_info in model.model_fields.values():
if (anno := ModelStore.to_pydantic(f_info.annotation)) is not None:
Expand Down Expand Up @@ -62,8 +81,12 @@ def parse_name_version(cls, fullname: str) -> tuple[str, int]:
@classmethod
def remove(cls, fr: type) -> None:
version = fr._version # type: ignore[attr-defined]
if fr.__name__ in cls.store and version in cls.store[fr.__name__]:
del cls.store[fr.__name__][version]
base_name = cls._base_name(fr)
unique_name = fr.__name__

for name in (base_name, unique_name):
if name in cls.store and version in cls.store[name]:
del cls.store[name][version]

@staticmethod
def is_pydantic(val: Any) -> bool:
Expand Down
Loading