Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/datachain/lib/dc/datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ def save( # type: ignore[override]
# current latest version instead.
from .datasets import read_dataset

return read_dataset(name, **kwargs)
return read_dataset(
name, namespace=namespace_name, project=project_name, **kwargs
)

return self._evolve(
query=self._query.save(
Expand Down
90 changes: 86 additions & 4 deletions tests/func/test_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,35 @@
from datachain.lib.file import File, ImageFile


def _get_short_ds_name(catalog, name, project_name, namespace_name) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why _get_short_ds_name? IMO it should be something like _get_full_ds_name 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is trying to get the shortest name (check if namespace passed is default and project is default and doesn't use them)

if project_name == catalog.metastore.default_project.name:
return name
if namespace_name == catalog.metastore.default_project.namespace.name:
return f"{project_name}.{name}"
return f"{namespace_name}.{project_name}.{name}"


def _get_dependencies(catalog, name, version) -> list[tuple[str, str]]:
namespace_name, project_name, name = catalog.get_full_dataset_name(name)
return sorted(
[
(d.name, d.version)
for d in catalog.get_dataset_dependencies(name, version, indirect=False)
(_get_short_ds_name(catalog, d.name, d.project, d.namespace), d.version)
for d in catalog.get_dataset_dependencies(
name,
version,
project_name=project_name,
namespace_name=namespace_name,
indirect=False,
)
]
)


def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path):
catalog = test_session.catalog

starting_ds_name = "starting_ds"
ds_name = "delta_ds"
starting_ds_name = "project.starting_ds"
ds_name = "project.delta_ds"

images = [
{"name": "img1.jpg", "data": Image.new(mode="RGB", size=(64, 64))},
Expand Down Expand Up @@ -82,6 +97,73 @@ def create_delta_dataset(ds_name):
create_delta_dataset(ds_name)


def test_delta_returns_correct_dataset_on_no_changes(test_session):
catalog = test_session.catalog

default_project_name = catalog.metastore.default_project.name
default_namespace_name = catalog.metastore.default_project.namespace.name

base_short = "same_name_base"
delta_short = "same_name_delta"

cases = [
{"ns": default_namespace_name, "proj": default_project_name, "ids": [1, 2]},
{"ns": default_namespace_name, "proj": "project_other", "ids": [10, 20, 30]},
{"ns": "namespace_other", "proj": "project_other", "ids": [100, 200]},
]

# First pass: create starting and delta datasets (v1)
for case in cases:
ns = case["ns"]
proj = case["proj"]
ids = case["ids"]

starting_ds_name = _get_short_ds_name(catalog, base_short, proj, ns)
delta_ds_name = _get_short_ds_name(catalog, delta_short, proj, ns)

dc.read_values(id=ids, session=test_session).save(starting_ds_name)

dc.read_dataset(
starting_ds_name,
session=test_session,
delta=True,
delta_on="id",
delta_compare="id",
).save(delta_ds_name)

assert _get_dependencies(catalog, delta_ds_name, "1.0.0") == [
(starting_ds_name, "1.0.0")
]
assert set(
dc.read_dataset(delta_ds_name, version="1.0.0").to_values("id")
) == set(ids)

# Second pass: re-save with no changes, ensure it returns the existing version
for case in cases:
ns = case["ns"]
proj = case["proj"]
ids = case["ids"]

starting_ds_name = _get_short_ds_name(catalog, base_short, proj, ns)
delta_ds_name = _get_short_ds_name(catalog, delta_short, proj, ns)

res = dc.read_dataset(
starting_ds_name,
session=test_session,
delta=True,
delta_on="id",
delta_compare="id",
).save(delta_ds_name)

# Should return the dataset with the same contents (no changes)
assert res.dataset is not None
assert set(dc.read_dataset(delta_ds_name).to_values("id")) == set(ids)

# Still no newer version available
with pytest.raises(DatasetNotFoundError):
dc.read_dataset(delta_ds_name, version="1.0.1")


def test_delta_update_unsafe(test_session):
catalog = test_session.catalog

Expand Down
Loading