Merged
88 changes: 72 additions & 16 deletions src/datachain/delta.py
@@ -6,6 +6,7 @@
import datachain
from datachain.dataset import DatasetDependency
from datachain.error import DatasetNotFoundError
from datachain.project import Project

if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
@@ -50,15 +51,24 @@

def _get_delta_chain(
source_ds_name: str,
source_ds_project: Project,
source_ds_version: str,
source_ds_latest_version: str,
on: Union[str, Sequence[str]],
compare: Optional[Union[str, Sequence[str]]] = None,
) -> "DataChain":
"""Get delta chain for processing changes between versions."""
source_dc = datachain.read_dataset(source_ds_name, version=source_ds_version)
source_dc = datachain.read_dataset(
source_ds_name,
namespace=source_ds_project.namespace.name,
project=source_ds_project.name,
version=source_ds_version,
)
source_dc_latest = datachain.read_dataset(
source_ds_name, version=source_ds_latest_version
source_ds_name,
namespace=source_ds_project.namespace.name,
project=source_ds_project.name,
version=source_ds_latest_version,
)

# Calculate diff between source versions
@@ -67,8 +77,10 @@

def _get_retry_chain(
name: str,
project: Project,
latest_version: str,
source_ds_name: str,
source_ds_project: Project,
source_ds_version: str,
on: Union[str, Sequence[str]],
right_on: Optional[Union[str, Sequence[str]]],
@@ -82,8 +94,18 @@
retry_chain = None

# Read the latest version of the result dataset for retry logic
result_dataset = datachain.read_dataset(name, version=latest_version)
source_dc = datachain.read_dataset(source_ds_name, version=source_ds_version)
result_dataset = datachain.read_dataset(
name,
namespace=project.namespace.name,
project=project.name,
version=latest_version,
)
source_dc = datachain.read_dataset(
source_ds_name,
namespace=source_ds_project.namespace.name,
project=source_ds_project.name,
version=source_ds_version,
)

# Handle error records if delta_retry is a string (column name)
if isinstance(delta_retry, str):
@@ -106,10 +128,15 @@

def _get_source_info(
name: str,
project: Project,
latest_version: str,
catalog,
) -> tuple[
Optional[str], Optional[str], Optional[str], Optional[list[DatasetDependency]]
Optional[str],
Optional[Project],
Optional[str],
Optional[str],
Optional[list[DatasetDependency]],
]:
"""Get source dataset information and dependencies.

@@ -118,23 +145,34 @@
Returns (None, None, None, None, None) if source dataset was removed.
"""
dependencies = catalog.get_dataset_dependencies(
name, latest_version, indirect=False
name, latest_version, project=project, indirect=False
)

dep = dependencies[0]
if not dep:
# Starting dataset was removed, back off to normal dataset creation
return None, None, None, None
return None, None, None, None, None

Codecov / codecov/patch warning (src/datachain/delta.py#L154): added line #L154 was not covered by tests.

source_ds_project = catalog.metastore.get_project(dep.project, dep.namespace)
source_ds_name = dep.name
source_ds_version = dep.version
source_ds_latest_version = catalog.get_dataset(source_ds_name).latest_version

return source_ds_name, source_ds_version, source_ds_latest_version, dependencies
source_ds_latest_version = catalog.get_dataset(
source_ds_name, project=source_ds_project
).latest_version

return (
source_ds_name,
source_ds_project,
source_ds_version,
source_ds_latest_version,
dependencies,
)


def delta_retry_update(
dc: "DataChain",
namespace_name: str,
project_name: str,
name: str,
on: Union[str, Sequence[str]],
right_on: Optional[Union[str, Sequence[str]]] = None,
@@ -173,11 +211,12 @@
"""

catalog = dc.session.catalog
project = catalog.metastore.get_project(project_name, namespace_name)
dc._query.apply_listing_pre_step()

# Check if dataset exists
try:
dataset = catalog.get_dataset(name)
dataset = catalog.get_dataset(name, project=project)
latest_version = dataset.latest_version
except DatasetNotFoundError:
# First creation of result dataset
@@ -189,19 +228,29 @@
retry_chain = None
processing_chain = None

source_ds_name, source_ds_version, source_ds_latest_version, dependencies = (
_get_source_info(name, latest_version, catalog)
)
(
source_ds_name,
source_ds_project,
source_ds_version,
source_ds_latest_version,
dependencies,
) = _get_source_info(name, project, latest_version, catalog)

# If source_ds_name is None, starting dataset was removed
if source_ds_name is None:
return None, None, True

assert source_ds_project
Contributor
suggestion (bug_risk): Use of assert for runtime validation may be bypassed in optimized mode.

Asserts are skipped with Python optimizations. Raise an exception instead if source_ds_project is essential.

Suggested change
- assert source_ds_project
+ if not source_ds_project:
+     raise ValueError("source_ds_project must be set")
assert source_ds_version
assert source_ds_latest_version

diff_chain = _get_delta_chain(
source_ds_name, source_ds_version, source_ds_latest_version, on, compare
source_ds_name,
source_ds_project,
source_ds_version,
source_ds_latest_version,
on,
compare,
)

# Filter out removed dep
@@ -215,8 +264,10 @@
if delta_retry:
retry_chain = _get_retry_chain(
name,
project,
latest_version,
source_ds_name,
source_ds_project,
source_ds_version,
on,
right_on,
@@ -237,7 +288,12 @@
if processing_chain is None or (processing_chain and processing_chain.empty):
return None, None, False

latest_dataset = datachain.read_dataset(name, version=latest_version)
latest_dataset = datachain.read_dataset(
name,
namespace=project.namespace.name,
project=project.name,
version=latest_version,
)
compared_chain = latest_dataset.diff(
processing_chain,
on=right_on or on,
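
To make the new call shape concrete, here is a minimal sketch of the project-scoped read path this diff threads through the delta logic. The keyword arguments mirror the `datachain.read_dataset` calls added above; the dataset, namespace, and project names are hypothetical.

import datachain

# Hypothetical names; keyword arguments mirror the calls added in this diff.
source = datachain.read_dataset(
    "starting_ds",
    namespace="global",  # namespace that owns the project
    project="dev",       # project that owns the dataset
    version="1.0.0",     # explicit version, as in _get_delta_chain
)
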
2 changes: 2 additions & 0 deletions src/datachain/lib/dc/datachain.py
@@ -598,6 +598,8 @@ def save(  # type: ignore[override]

result_ds, dependencies, has_changes = delta_retry_update(
self,
namespace_name,
project_name,
name,
on=self._delta_on,
right_on=self._delta_result_on,
25 changes: 20 additions & 5 deletions tests/func/test_delta.py
@@ -14,15 +14,26 @@
def _get_dependencies(catalog, name, version) -> list[tuple[str, str]]:
return sorted(
[
(d.name, d.version)
(f"{d.namespace}.{d.project}.{d.name}", d.version)
for d in catalog.get_dataset_dependencies(name, version, indirect=False)
]
)


def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path):
@pytest.mark.parametrize("project", ("global.dev", ""))
def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path, project):
catalog = test_session.catalog
starting_ds_name = "starting_ds"
default_namespace_name = catalog.metastore.default_namespace_name
default_project_name = catalog.metastore.default_project_name

if project:
starting_ds_name = f"{project}.starting_ds"
dependency_ds_name = starting_ds_name
else:
starting_ds_name = "starting_ds"
dependency_ds_name = (
f"{default_namespace_name}.{default_project_name}.{starting_ds_name}"
)
ds_name = "delta_ds"

images = [
@@ -55,12 +66,16 @@ def create_delta_dataset(ds_name):
create_image_dataset(starting_ds_name, images[:2])
# first version of delta dataset
create_delta_dataset(ds_name)
assert _get_dependencies(catalog, ds_name, "1.0.0") == [(starting_ds_name, "1.0.0")]
assert _get_dependencies(catalog, ds_name, "1.0.0") == [
(dependency_ds_name, "1.0.0")
]
# second version of starting dataset
create_image_dataset(starting_ds_name, images[2:])
# second version of delta dataset
create_delta_dataset(ds_name)
assert _get_dependencies(catalog, ds_name, "1.0.1") == [(starting_ds_name, "1.0.1")]
assert _get_dependencies(catalog, ds_name, "1.0.1") == [
(dependency_ds_name, "1.0.1")
]

assert (dc.read_dataset(ds_name, version="1.0.0").order_by("file.path")).to_values(
"file.path"
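
Read together with the `_get_dependencies` helper change at the top of this test file, the parametrized test pins down the new fully qualified dependency format. A sketch restating that expectation (dataset names are taken from the test; the default namespace/project values are assumptions):

# Dependencies are now reported as ("<namespace>.<project>.<dataset>", "<version>").
# With the "global.dev" parametrization:
assert _get_dependencies(catalog, "delta_ds", "1.0.0") == [
    ("global.dev.starting_ds", "1.0.0")
]
# With no explicit project, names fall back to the metastore defaults,
# i.e. f"{default_namespace_name}.{default_project_name}.starting_ds".
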