diff --git a/dvc/repo/data.py b/dvc/repo/data.py
index 0f2cd71728..da0e555297 100644
--- a/dvc/repo/data.py
+++ b/dvc/repo/data.py
@@ -1,20 +1,25 @@
 import os
 import posixpath
-from collections.abc import Iterable
+from collections import defaultdict
+from collections.abc import Iterable, Iterator, Mapping
 from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
-from dvc.fs.callbacks import DEFAULT_CALLBACK
+from dvc.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
 from dvc.log import logger
-from dvc.scm import RevError
 from dvc.ui import ui
-from dvc_data.index.view import DataIndexView
 
 if TYPE_CHECKING:
-    from dvc.fs.callbacks import Callback
     from dvc.repo import Repo
     from dvc.scm import Git, NoSCM
-    from dvc_data.index import BaseDataIndex, DataIndex, DataIndexKey
+    from dvc_data.index import (
+        BaseDataIndex,
+        DataIndex,
+        DataIndexEntry,
+        DataIndexKey,
+        DataIndexView,
+    )
     from dvc_data.index.diff import Change
+    from dvc_objects.fs.base import FileSystem
 
 logger = logger.getChild(__name__)
 
@@ -23,21 +28,6 @@ def posixpath_to_os_path(path: str) -> str:
     return path.replace(posixpath.sep, os.path.sep)
 
 
-def _adapt_typ(typ: str) -> str:
-    from dvc_data.index.diff import ADD, DELETE, MODIFY
-
-    if typ == MODIFY:
-        return "modified"
-
-    if typ == ADD:
-        return "added"
-
-    if typ == DELETE:
-        return "deleted"
-
-    return typ
-
-
 def _adapt_path(change: "Change") -> str:
     isdir = False
     if change.new and change.new.meta:
@@ -50,25 +40,65 @@ def _adapt_path(change: "Change") -> str:
     return os.path.sep.join(key)
 
 
+def _get_missing_paths(
+    to_check: Mapping["FileSystem", Mapping[str, Iterable["DataIndexEntry"]]],
+    batch_size: Optional[int] = None,
+    callback: "Callback" = DEFAULT_CALLBACK,
+) -> Iterator[str]:
+    for fs, paths_map in to_check.items():
+        if batch_size == 1 or (batch_size is None and fs.protocol == "local"):
+            results = list(callback.wrap(map(fs.exists, paths_map)))
+        else:
+            results = fs.exists(
+                list(paths_map), batch_size=batch_size, callback=callback
+            )
+
+        for cache_path, exists in zip(paths_map, results):
+            if exists:
+                continue
+
+            for entry in paths_map[cache_path]:
+                key = entry.key
+                assert key
+                if entry.meta and entry.meta.isdir:
+                    key = (*key, "")
+                yield os.path.sep.join(key)
+
+
+class StorageCallback(Callback):
+    def __init__(self, parent_cb: Callback) -> None:
+        super().__init__(size=0, value=0)
+        self.parent_cb = parent_cb
+
+    def set_size(self, size: int) -> None:
+        # This is a no-op to prevent `fs.exists` from trying to set the size
+        pass
+
+    def relative_update(self, value: int = 1) -> None:
+        self.parent_cb.relative_update(value)
+
+    def absolute_update(self, value: int) -> None:
+        self.parent_cb.relative_update(value - self.value)
+
+
 def _diff(
     old: "BaseDataIndex",
     new: "BaseDataIndex",
     *,
+    filter_keys: Optional[Iterable["DataIndexKey"]] = None,
     granular: bool = False,
     not_in_cache: bool = False,
+    batch_size: Optional[int] = None,
     callback: "Callback" = DEFAULT_CALLBACK,
-    filter_keys: Optional[list["DataIndexKey"]] = None,
 ) -> dict[str, list[str]]:
-    from dvc_data.index.diff import UNCHANGED, UNKNOWN, diff
+    from dvc_data.index.diff import ADD, DELETE, MODIFY, UNCHANGED, UNKNOWN, diff
 
-    ret: dict[str, list[str]] = {}
+    ret: dict[str, list[str]] = defaultdict(list)
+    change_types = {MODIFY: "modified", ADD: "added", DELETE: "deleted"}
 
-    def _add_change(typ, change):
-        typ = _adapt_typ(typ)
-        if typ not in ret:
-            ret[typ] = []
-
-        ret[typ].append(_adapt_path(change))
+    to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
 
     for change in diff(
         old,
@@ -84,9 +114,7 @@ def _add_change(typ, change):
         # still appear in the view. As a result, keys like `dir/` will be present
         # even if only `dir/file` matches the filter.
        # We need to skip such entries to avoid showing root of tracked directories.
-        if filter_keys and not any(
-            change.key[: len(filter_key)] == filter_key for filter_key in filter_keys
-        ):
+        if filter_keys and not any(change.key[: len(fk)] == fk for fk in filter_keys):
             continue
 
         if (
@@ -101,18 +129,27 @@ def _add_change(typ, change):
             # NOTE: emulating previous behaviour
             continue
 
-        if (
-            not_in_cache
-            and change.old
-            and change.old.hash_info
-            and not old.storage_map.cache_exists(change.old)
-        ):
-            # NOTE: emulating previous behaviour
-            _add_change("not_in_cache", change)
+        if not_in_cache and change.old and change.old.hash_info:
+            old_entry = change.old
+            cache_fs, cache_path = old.storage_map.get_cache(old_entry)
+            # check later in batches
+            to_check[cache_fs][cache_path].append(old_entry)
 
-        _add_change(change.typ, change)
+        change_typ = change_types.get(change.typ, change.typ)
+        ret[change_typ].append(_adapt_path(change))
 
-    return ret
+    total_items = sum(
+        len(entries) for paths in to_check.values() for entries in paths.values()
+    )
+    with TqdmCallback(size=total_items, desc="Checking cache", unit="entry") as cb:
+        missing_items = list(
+            _get_missing_paths(
+                to_check, batch_size=batch_size, callback=StorageCallback(cb)
+            ),
+        )
+    if missing_items:
+        ret["not_in_cache"] = missing_items
+    return dict(ret)
 
 
 class GitInfo(TypedDict, total=False):
@@ -153,8 +190,10 @@ def _git_info(scm: Union["Git", "NoSCM"], untracked_files: str = "all") -> GitIn
 
 
 def filter_index(
     index: Union["DataIndex", "DataIndexView"],
-    filter_keys: Optional[list["DataIndexKey"]] = None,
+    filter_keys: Optional[Iterable["DataIndexKey"]] = None,
 ) -> "BaseDataIndex":
+    from dvc_data.index.view import DataIndexView
+
     if not filter_keys:
         return index
@@ -187,8 +226,9 @@ def filter_fn(key: "DataIndexKey") -> bool:
 
 
 def _diff_index_to_wtree(
     repo: "Repo",
-    filter_keys: Optional[list["DataIndexKey"]] = None,
+    filter_keys: Optional[Iterable["DataIndexKey"]] = None,
     granular: bool = False,
+    batch_size: Optional[int] = None,
 ) -> dict[str, list[str]]:
     from .index import build_data_index
@@ -214,6 +254,7 @@ def _diff_index_to_wtree(
             filter_keys=filter_keys,
             granular=granular,
             not_in_cache=True,
+            batch_size=batch_size,
             callback=pb.as_callback(),
         )
 
@@ -221,9 +262,10 @@ def _diff_index_to_wtree(
 def _diff_head_to_index(
     repo: "Repo",
     head: str = "HEAD",
-    filter_keys: Optional[list["DataIndexKey"]] = None,
+    filter_keys: Optional[Iterable["DataIndexKey"]] = None,
     granular: bool = False,
 ) -> dict[str, list[str]]:
+    from dvc.scm import RevError
     from dvc_data.index import DataIndex
 
     index = repo.index.data["repo"]
@@ -278,9 +320,10 @@ def _transform_git_paths_to_dvc(repo: "Repo", files: Iterable[str]) -> list[str]
 
 def _get_entries_not_in_remote(
     repo: "Repo",
-    filter_keys: Optional[list["DataIndexKey"]] = None,
+    filter_keys: Optional[Iterable["DataIndexKey"]] = None,
     granular: bool = False,
     remote_refresh: bool = False,
+    batch_size: Optional[int] = None,
 ) -> list[str]:
     """Get entries that are not in remote storage."""
     from dvc.repo.worktree import worktree_view
@@ -293,7 +336,13 @@ def _get_entries_not_in_remote(
     view = filter_index(data_index, filter_keys=filter_keys)  # type: ignore[arg-type]
 
     missing_entries = []
-    with ui.progress(desc="Checking remote", unit="entry") as pb:
+
+    to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+
+    storage_map = view.storage_map
+    with TqdmCallback(size=0, desc="Checking remote", unit="entry") as cb:
         for key, entry in view.iteritems(shallow=not granular):
             if not (entry and entry.hash_info):
                 continue
@@ -309,13 +358,28 @@ def _get_entries_not_in_remote(
                 continue
 
             k = (*key, "") if entry.meta and entry.meta.isdir else key
-            try:
-                if not view.storage_map.remote_exists(entry, refresh=remote_refresh):
-                    missing_entries.append(os.path.sep.join(k))
-                pb.update()
-            except StorageKeyError:
-                pass
-
+            if remote_refresh:
+                # on remote_refresh, collect all entries to check
+                # then check them in batches below
+                try:
+                    remote_fs, remote_path = storage_map.get_remote(entry)
+                    to_check[remote_fs][remote_path].append(entry)
+                    cb.size += 1
+                    cb.relative_update(0)  # try to update the progress bar
+                except StorageKeyError:
+                    pass
+            else:
+                try:
+                    if not storage_map.remote_exists(entry, refresh=remote_refresh):
+                        missing_entries.append(os.path.sep.join(k))
+                    cb.relative_update()  # no need to update the size
+                except StorageKeyError:
+                    pass
+        missing_entries.extend(
+            _get_missing_paths(
+                to_check, batch_size=batch_size, callback=StorageCallback(cb)
+            )
+        )
     return missing_entries
 
 
@@ -324,7 +388,7 @@ def _matches_target(p: str, targets: Iterable[str]) -> bool:
     return any(p == t or p.startswith(t + sep) for t in targets)
 
 
-def _prune_keys(filter_keys: list["DataIndexKey"]) -> list["DataIndexKey"]:
+def _prune_keys(filter_keys: Iterable["DataIndexKey"]) -> list["DataIndexKey"]:
     sorted_keys = sorted(set(filter_keys), key=len)
     result: list[DataIndexKey] = []
 
@@ -342,6 +406,7 @@ def status(
     untracked_files: str = "no",
     not_in_remote: bool = False,
    remote_refresh: bool = False,
+    batch_size: Optional[int] = None,
     head: str = "HEAD",
 ) -> Status:
     from dvc.scm import NoSCMError, SCMError
@@ -352,19 +417,19 @@ def status(
     filter_keys = _prune_keys(filter_keys)
 
     uncommitted_diff = _diff_index_to_wtree(
-        repo, filter_keys=filter_keys, granular=granular
+        repo, filter_keys=filter_keys, granular=granular, batch_size=batch_size
     )
     unchanged = set(uncommitted_diff.pop("unchanged", []))
-    entries_not_in_remote = (
-        _get_entries_not_in_remote(
+
+    entries_not_in_remote: list[str] = []
+    if not_in_remote:
+        entries_not_in_remote = _get_entries_not_in_remote(
             repo,
             filter_keys=filter_keys,
             granular=granular,
            remote_refresh=remote_refresh,
+            batch_size=batch_size,
        )
-        if not_in_remote
-        else []
-    )
 
     try:
        committed_diff = _diff_head_to_index(
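Note on the batching pattern: `_get_missing_paths` groups candidate paths per filesystem, then either checks them one at a time (local filesystems, or `batch_size == 1`) or hands the whole list to a single vectorized `fs.exists` call that the filesystem can fan out concurrently. A minimal, self-contained sketch of that control flow; `FakeFileSystem` and `missing_paths` are illustrative stand-ins, not the dvc_objects API (in particular, the real `fs.exists` also accepts `batch_size` and `callback`, which the fake does not model):

```python
from collections.abc import Iterable, Iterator, Mapping
from typing import Optional, Union


class FakeFileSystem:
    """Illustrative stand-in for dvc_objects' FileSystem (not the real API)."""

    protocol = "memory"  # anything but "local" takes the batched branch

    def __init__(self, existing: set[str]) -> None:
        self._existing = existing

    def exists(self, path: Union[str, list[str]]) -> Union[bool, list[bool]]:
        # fsspec-style: a single path yields a bool, a list yields a list.
        if isinstance(path, str):
            return path in self._existing
        return [p in self._existing for p in path]


def missing_paths(
    to_check: Mapping[FakeFileSystem, Iterable[str]],
    batch_size: Optional[int] = None,
) -> Iterator[str]:
    for fs, paths in to_check.items():
        paths = list(paths)
        if batch_size == 1 or (batch_size is None and fs.protocol == "local"):
            # Local checks are cheap syscalls: do them sequentially.
            results = [fs.exists(p) for p in paths]
        else:
            # Remote checks are network round-trips: issue one batched call.
            results = fs.exists(paths)
        for path, exists in zip(paths, results):
            if not exists:
                yield path


fs = FakeFileSystem({"cache/ab/cdef"})
print(list(missing_paths({fs: ["cache/ab/cdef", "cache/12/3456"]})))
# -> ['cache/12/3456']
```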
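The `StorageCallback` wrapper exists because each per-filesystem `fs.exists` call would otherwise call `set_size` on the shared progress callback and clobber the caller's overall total. A rough model of that proxying behavior, with plain classes standing in for dvc's fsspec-based `Callback` (the delta bookkeeping in `absolute_update` is part of the sketch, not copied from the PR):

```python
class Callback:
    """Simplified model of the fsspec-style callback dvc builds on."""

    def __init__(self, size: int = 0, value: int = 0) -> None:
        self.size = size
        self.value = value

    def set_size(self, size: int) -> None:
        self.size = size

    def relative_update(self, value: int = 1) -> None:
        self.value += value

    def absolute_update(self, value: int) -> None:
        self.value = value


class ProxyCallback(Callback):
    """Forwards progress to a parent while protecting the parent's total."""

    def __init__(self, parent: Callback) -> None:
        super().__init__(size=0, value=0)
        self.parent = parent

    def set_size(self, size: int) -> None:
        pass  # ignore: the parent already knows the overall total

    def relative_update(self, value: int = 1) -> None:
        self.value += value
        self.parent.relative_update(value)

    def absolute_update(self, value: int) -> None:
        # Translate an absolute position into a delta for the parent.
        delta = value - self.value
        self.value = value
        self.parent.relative_update(delta)


parent = Callback(size=10)
child = ProxyCallback(parent)
child.set_size(4)         # ignored: parent.size stays 10
child.absolute_update(3)  # parent.value -> 3
child.relative_update()   # parent.value -> 4
assert (parent.size, parent.value) == (10, 4)
```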
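For callers, the visible change is the `batch_size` knob threaded from `status()` through `_diff_index_to_wtree` and `_get_entries_not_in_remote` into `_get_missing_paths`. A hedged usage sketch, assuming a DVC repository in the current directory and that `status` takes the repo as its first argument (only the keyword parameters shown in this diff are used):

```python
from dvc.repo import Repo
from dvc.repo.data import status

repo = Repo()  # assumes the current directory is inside a DVC repo
st = status(
    repo,
    not_in_remote=True,    # also report entries missing from remote storage
    remote_refresh=True,   # re-check the remote instead of trusting cached state
    batch_size=100,        # None: fs-dependent default; 1: sequential checks
)
print(st)
```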