189 changes: 127 additions & 62 deletions dvc/repo/data.py
@@ -1,20 +1,25 @@
import os
import posixpath
from collections.abc import Iterable
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping
from typing import TYPE_CHECKING, Optional, TypedDict, Union

from dvc.fs.callbacks import DEFAULT_CALLBACK
from dvc.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
from dvc.log import logger
from dvc.scm import RevError
from dvc.ui import ui
from dvc_data.index.view import DataIndexView

if TYPE_CHECKING:
from dvc.fs.callbacks import Callback
from dvc.repo import Repo
from dvc.scm import Git, NoSCM
from dvc_data.index import BaseDataIndex, DataIndex, DataIndexKey
from dvc_data.index import (
BaseDataIndex,
DataIndex,
DataIndexEntry,
DataIndexKey,
DataIndexView,
)
from dvc_data.index.diff import Change
from dvc_objects.fs.base import FileSystem

logger = logger.getChild(__name__)

@@ -23,21 +28,6 @@ def posixpath_to_os_path(path: str) -> str:
return path.replace(posixpath.sep, os.path.sep)


def _adapt_typ(typ: str) -> str:
from dvc_data.index.diff import ADD, DELETE, MODIFY

if typ == MODIFY:
return "modified"

if typ == ADD:
return "added"

if typ == DELETE:
return "deleted"

return typ


def _adapt_path(change: "Change") -> str:
isdir = False
if change.new and change.new.meta:
@@ -50,25 +40,65 @@ def _adapt_path(change: "Change") -> str:
return os.path.sep.join(key)


def _get_missing_paths(
to_check: Mapping["FileSystem", Mapping[str, Iterable["DataIndexEntry"]]],
batch_size: Optional[int] = None,
callback: "Callback" = DEFAULT_CALLBACK,
) -> Iterator[str]:
for fs, paths_map in to_check.items():
if batch_size == 1 or (batch_size is None and fs.protocol == "local"):
results = list(callback.wrap(map(fs.exists, paths_map)))
else:
results = fs.exists(
list(paths_map), batch_size=batch_size, callback=callback
)

for cache_path, exists in zip(paths_map, results):
if exists:
continue

for entry in paths_map[cache_path]:
key = entry.key
assert key
if entry.meta and entry.meta.isdir:
key = (*key, "")
yield os.path.sep.join(key)

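A minimal sketch of the group-then-batch pattern `_get_missing_paths` implements, using a stub filesystem in place of dvc's `FileSystem` (the stub and its contents are hypothetical; only the call shape, a list of paths in and a list of booleans out, mirrors the code above):

from collections import defaultdict

class StubFS:
    protocol = "s3"

    def __init__(self, present):
        self._present = set(present)

    def exists(self, paths, batch_size=None, callback=None):
        # batched form assumed above: list of paths in, list of
        # booleans out, in the same order
        return [p in self._present for p in paths]

# group candidate cache paths by the filesystem they live on
to_check = defaultdict(lambda: defaultdict(list))
fs = StubFS(present={"cache/ab/cd"})
to_check[fs]["cache/ab/cd"].append("entry-1")  # present: not reported
to_check[fs]["cache/ef/01"].append("entry-2")  # absent: reported missing

for fs, paths_map in to_check.items():
    for path, exists in zip(paths_map, fs.exists(list(paths_map))):
        if not exists:
            print("missing:", path, "->", paths_map[path])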

class StorageCallback(Callback):
def __init__(self, parent_cb: Callback) -> None:
super().__init__(size=0, value=0)
self.parent_cb = parent_cb

def set_size(self, size: int) -> None:
# This is a no-op to prevent `fs.exists` from trying to set the size
pass

def relative_update(self, value: int = 1) -> None:
self.parent_cb.relative_update(value)

def absolute_update(self, value: int) -> None:
self.parent_cb.relative_update(value - self.value)

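`StorageCallback` funnels per-`fs.exists` progress into one shared bar: `set_size` is swallowed because the shared bar's total is managed by the caller, and absolute updates are converted into relative increments. A rough standalone illustration of that proxy idea (plain classes rather than dvc's `Callback` API; the value bookkeeping here is simplified):

class Parent:
    def __init__(self):
        self.value = 0

    def relative_update(self, value=1):
        self.value += value
        print("overall progress:", self.value)

class Proxy:
    def __init__(self, parent):
        self.parent = parent
        self.value = 0  # what the wrapped call has reported so far

    def set_size(self, size):
        pass  # ignore the inner call's total; the outer bar owns it

    def relative_update(self, value=1):
        self.value += value
        self.parent.relative_update(value)

    def absolute_update(self, value):
        # translate "I am at `value`" into a relative increment
        self.parent.relative_update(value - self.value)
        self.value = value

proxy = Proxy(Parent())
proxy.relative_update(2)  # overall progress: 2
proxy.absolute_update(5)  # advances the parent by 3: overall progress 5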

def _diff(
old: "BaseDataIndex",
new: "BaseDataIndex",
*,
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
granular: bool = False,
not_in_cache: bool = False,
batch_size: Optional[int] = None,
callback: "Callback" = DEFAULT_CALLBACK,
filter_keys: Optional[list["DataIndexKey"]] = None,
) -> dict[str, list[str]]:
from dvc_data.index.diff import UNCHANGED, UNKNOWN, diff
from dvc_data.index.diff import ADD, DELETE, MODIFY, UNCHANGED, UNKNOWN, diff

ret: dict[str, list[str]] = {}
ret: dict[str, list[str]] = defaultdict(list)
change_types = {MODIFY: "modified", ADD: "added", DELETE: "deleted"}

def _add_change(typ, change):
typ = _adapt_typ(typ)
if typ not in ret:
ret[typ] = []

ret[typ].append(_adapt_path(change))
to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
lambda: defaultdict(list)
)

for change in diff(
old,
@@ -84,9 +114,7 @@ def _add_change(typ, change):
# still appear in the view. As a result, keys like `dir/` will be present
# even if only `dir/file` matches the filter.
# We need to skip such entries to avoid showing root of tracked directories.
if filter_keys and not any(
change.key[: len(filter_key)] == filter_key for filter_key in filter_keys
):
if filter_keys and not any(change.key[: len(fk)] == fk for fk in filter_keys):
continue

if (
@@ -101,18 +129,27 @@ def _add_change(typ, change):
# NOTE: emulating previous behaviour
continue

if (
not_in_cache
and change.old
and change.old.hash_info
and not old.storage_map.cache_exists(change.old)
):
# NOTE: emulating previous behaviour
_add_change("not_in_cache", change)
if not_in_cache and change.old and change.old.hash_info:
old_entry = change.old
cache_fs, cache_path = old.storage_map.get_cache(old_entry)
# check later in batches
to_check[cache_fs][cache_path].append(old_entry)

_add_change(change.typ, change)
change_typ = change_types.get(change.typ, change.typ)
ret[change_typ].append(_adapt_path(change))

return ret
total_items = sum(
len(entries) for paths in to_check.values() for entries in paths.values()
)
with TqdmCallback(size=total_items, desc="Checking cache", unit="entry") as cb:
missing_items = list(
_get_missing_paths(
to_check, batch_size=batch_size, callback=StorageCallback(cb)
),
)
if missing_items:
ret["not_in_cache"] = missing_items
return dict(ret)

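The `change_types` lookup with a `dict.get` fallback replaces the removed `_adapt_typ` if-chain, and the `defaultdict` drops the old need to pre-create buckets. A tiny self-contained demo of the same idiom (the constant values are placeholders for `dvc_data.index.diff`'s real ones):

from collections import defaultdict

MODIFY, ADD, DELETE = "modify", "add", "delete"  # placeholder values
change_types = {MODIFY: "modified", ADD: "added", DELETE: "deleted"}

ret = defaultdict(list)
for typ, path in [(MODIFY, "data.csv"), (ADD, "new.txt"), ("unknown", "x")]:
    ret[change_types.get(typ, typ)].append(path)  # unmapped types pass through

print(dict(ret))
# {'modified': ['data.csv'], 'added': ['new.txt'], 'unknown': ['x']}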

class GitInfo(TypedDict, total=False):
@@ -153,8 +190,10 @@ def _git_info(scm: Union["Git", "NoSCM"], untracked_files: str = "all") -> GitInfo:

def filter_index(
index: Union["DataIndex", "DataIndexView"],
filter_keys: Optional[list["DataIndexKey"]] = None,
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
) -> "BaseDataIndex":
from dvc_data.index.view import DataIndexView

if not filter_keys:
return index

@@ -187,8 +226,9 @@ def filter_fn(key: "DataIndexKey") -> bool:

def _diff_index_to_wtree(
repo: "Repo",
filter_keys: Optional[list["DataIndexKey"]] = None,
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
granular: bool = False,
batch_size: Optional[int] = None,
) -> dict[str, list[str]]:
from .index import build_data_index

@@ -214,16 +254,18 @@ def _diff_index_to_wtree(
filter_keys=filter_keys,
granular=granular,
not_in_cache=True,
batch_size=batch_size,
callback=pb.as_callback(),
)


def _diff_head_to_index(
repo: "Repo",
head: str = "HEAD",
filter_keys: Optional[list["DataIndexKey"]] = None,
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
granular: bool = False,
) -> dict[str, list[str]]:
from dvc.scm import RevError
from dvc_data.index import DataIndex

index = repo.index.data["repo"]
@@ -278,9 +320,10 @@ def _transform_git_paths_to_dvc(repo: "Repo", files: Iterable[str]) -> list[str]:

def _get_entries_not_in_remote(
repo: "Repo",
filter_keys: Optional[list["DataIndexKey"]] = None,
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
granular: bool = False,
remote_refresh: bool = False,
batch_size: Optional[int] = None,
) -> list[str]:
"""Get entries that are not in remote storage."""
from dvc.repo.worktree import worktree_view
@@ -293,7 +336,13 @@ def _get_entries_not_in_remote(
view = filter_index(data_index, filter_keys=filter_keys) # type: ignore[arg-type]

missing_entries = []
with ui.progress(desc="Checking remote", unit="entry") as pb:

to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
lambda: defaultdict(list)
)

storage_map = view.storage_map
with TqdmCallback(size=0, desc="Checking remote", unit="entry") as cb:
for key, entry in view.iteritems(shallow=not granular):
if not (entry and entry.hash_info):
continue
@@ -309,13 +358,28 @@
continue

k = (*key, "") if entry.meta and entry.meta.isdir else key
try:
if not view.storage_map.remote_exists(entry, refresh=remote_refresh):
missing_entries.append(os.path.sep.join(k))
pb.update()
except StorageKeyError:
pass

if remote_refresh:
# on remote_refresh, collect all entries to check
# then check them in batches below
try:
remote_fs, remote_path = storage_map.get_remote(entry)
to_check[remote_fs][remote_path].append(entry)
cb.size += 1
cb.relative_update(0) # try to update the progress bar
except StorageKeyError:
pass
else:
try:
if not storage_map.remote_exists(entry, refresh=remote_refresh):
missing_entries.append(os.path.sep.join(k))
cb.relative_update() # no need to update the size
except StorageKeyError:
pass
missing_entries.extend(
_get_missing_paths(
to_check, batch_size=batch_size, callback=StorageCallback(cb)
)
)
return missing_entries

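In the `remote_refresh` branch the total is unknown until iteration finishes, so the bar's size is grown entry by entry and a zero-increment update forces a redraw. The same trick shown with a bare `tqdm` bar rather than dvc's `TqdmCallback` wrapper, purely for illustration:

from tqdm import tqdm

with tqdm(total=0, desc="Checking remote", unit="entry") as pbar:
    for _ in range(3):       # entries discovered while walking the index
        pbar.total += 1      # grow the known total on the fly
        pbar.refresh()       # redraw so the new total shows immediately
    pbar.update(3)           # then report completion in one batch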

@@ -324,7 +388,7 @@ def _matches_target(p: str, targets: Iterable[str]) -> bool:
return any(p == t or p.startswith(t + sep) for t in targets)


def _prune_keys(filter_keys: list["DataIndexKey"]) -> list["DataIndexKey"]:
def _prune_keys(filter_keys: Iterable["DataIndexKey"]) -> list["DataIndexKey"]:
sorted_keys = sorted(set(filter_keys), key=len)
result: list[DataIndexKey] = []

@@ -342,6 +406,7 @@ def status(
untracked_files: str = "no",
not_in_remote: bool = False,
remote_refresh: bool = False,
batch_size: Optional[int] = None,
head: str = "HEAD",
) -> Status:
from dvc.scm import NoSCMError, SCMError
@@ -352,19 +417,19 @@
filter_keys = _prune_keys(filter_keys)

uncommitted_diff = _diff_index_to_wtree(
repo, filter_keys=filter_keys, granular=granular
repo, filter_keys=filter_keys, granular=granular, batch_size=batch_size
)
unchanged = set(uncommitted_diff.pop("unchanged", []))
entries_not_in_remote = (
_get_entries_not_in_remote(

entries_not_in_remote: list[str] = []
if not_in_remote:
entries_not_in_remote = _get_entries_not_in_remote(
repo,
filter_keys=filter_keys,
granular=granular,
remote_refresh=remote_refresh,
batch_size=batch_size,
)
if not_in_remote
else []
)

try:
committed_diff = _diff_head_to_index(
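
End to end, the new `batch_size` knob threads from `status` into both the cache check (`_diff_index_to_wtree` into `_diff`) and the remote check (`_get_entries_not_in_remote`). A hypothetical call, with argument values invented for illustration and result keys assumed from this module's `Status` TypedDict:

from dvc.repo import Repo

repo = Repo(".")
st = status(repo, not_in_remote=True, remote_refresh=True, batch_size=100)
print(st.get("not_in_cache", []))
print(st.get("not_in_remote", []))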