diff --git a/src/dvc_data/index/index.py b/src/dvc_data/index/index.py index 283066ad..b9fa3918 100644 --- a/src/dvc_data/index/index.py +++ b/src/dvc_data/index/index.py @@ -2,10 +2,13 @@ import logging import os from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Iterator, MutableMapping from typing import TYPE_CHECKING, Any, Callable, Optional, cast import attrs +from fsspec import Callback +from fsspec.callbacks import DEFAULT_CALLBACK from sqltrie import JSONTrie, PyGTrie, ShortKeyError, SQLiteTrie from dvc_data.compat import cached_property @@ -156,6 +159,19 @@ def exists(self, entry: "DataIndexEntry") -> bool: fs, path = self.get(entry) return fs.exists(path) + def bulk_exists( + self, + entries: list["DataIndexEntry"], + refresh: bool = False, + jobs: Optional[int] = None, + callback: "Callback" = DEFAULT_CALLBACK, + ) -> dict["DataIndexEntry", bool]: + results = {} + for entry in callback.wrap(entries): + results[entry] = self.exists(entry) + + return results + class ObjectStorage(Storage): def __init__( @@ -224,6 +240,66 @@ def exists(self, entry: "DataIndexEntry", refresh: bool = False) -> bool: finally: self.index.commit() + def bulk_exists( + self, + entries: list["DataIndexEntry"], + refresh: bool = False, + jobs: Optional[int] = None, + callback: "Callback" = DEFAULT_CALLBACK, + ) -> dict["DataIndexEntry", bool]: + if not entries: + return {} + + entries_with_hash = [e for e in entries if e.hash_info] + entries_without_hash = [e for e in entries if not e.hash_info] + results = dict.fromkeys(entries_without_hash, False) + callback.relative_update(len(entries_without_hash)) + + if self.index is None or not refresh: + for entry in callback.wrap(entries_with_hash): + assert entry.hash_info + value = cast("str", entry.hash_info.value) + if self.index is None: + exists = self.odb.exists(value) + else: + key = self.odb._oid_parts(value) + exists = key in self.index + results[entry] = exists + return results + + entry_map: dict[str, DataIndexEntry] = { + self.get(entry)[1]: entry for entry in entries_with_hash + } + info_results = self.fs.info( + list(entry_map.keys()), + batch_size=jobs, + return_exceptions=True, + callback=callback, + ) + + for (path, entry), info in zip(entry_map.items(), info_results): + assert entry.hash_info # built from entries_with_hash + value = cast("str", entry.hash_info.value) + key = self.odb._oid_parts(value) + + if isinstance(info, FileNotFoundError) or info is None: + self.index.pop(key, None) + results[entry] = False + continue + if isinstance(info, Exception): + raise info + + from .build import build_entry + + built_entry = build_entry(path, self.fs, info=info) + self.index[key] = built_entry + results[entry] = True + + if self.index is not None: + self.index.commit() + + return results + class FileStorage(Storage): def __init__( @@ -442,6 +518,69 @@ def remote_exists(self, entry: "DataIndexEntry", **kwargs) -> bool: return storage.remote.exists(entry, **kwargs) + def _bulk_storage_exists( + self, + entries: list[DataIndexEntry], + storage: str, + **kwargs, + ) -> dict[DataIndexEntry, bool]: + by_storage: dict[Storage, list[DataIndexEntry]] = defaultdict(list) + by_odb: dict[Optional[HashFileDB], dict[Storage, list[DataIndexEntry]]] = ( + defaultdict(lambda: defaultdict(list)) + ) + for entry in entries: + storage_info = self[entry.key] + storage_obj = getattr(storage_info, storage) if storage_info else None + if isinstance(storage_obj, ObjectStorage): + 
by_odb[storage_obj.odb][storage_obj].append(entry) + elif storage_obj is not None: + by_storage[storage_obj].append(entry) + + for storages in by_odb.values(): + assert storages # cannot be empty, we always add at least one entry + representative = next(iter(storages)) + by_storage[representative] = [ + e for entries in storages.values() for e in entries + ] + + results = {} + + for storage_obj, storage_entries in by_storage.items(): + results.update( + storage_obj.bulk_exists( + storage_entries, + **kwargs, + ) + ) + + return results + + def bulk_cache_exists( + self, + entries: list[DataIndexEntry], + callback: Callback = DEFAULT_CALLBACK, + **kwargs, + ) -> dict[DataIndexEntry, bool]: + return self._bulk_storage_exists( + entries, + "cache", + callback=callback, + **kwargs, + ) + + def bulk_remote_exists( + self, + entries: list[DataIndexEntry], + callback: Callback = DEFAULT_CALLBACK, + **kwargs, + ) -> dict[DataIndexEntry, bool]: + return self._bulk_storage_exists( + entries, + "remote", + callback=callback, + **kwargs, + ) + class BaseDataIndex(ABC, MutableMapping[DataIndexKey, DataIndexEntry]): storage_map: StorageMapping diff --git a/tests/index/test_storage.py b/tests/index/test_storage.py index c28d4523..737b98ed 100644 --- a/tests/index/test_storage.py +++ b/tests/index/test_storage.py @@ -1,6 +1,15 @@ from dvc_objects.fs.local import LocalFileSystem -from dvc_data.index import FileStorage, ObjectStorage, StorageInfo, StorageMapping +from dvc_data.hashfile.hash_info import HashInfo +from dvc_data.hashfile.meta import Meta +from dvc_data.index import ( + DataIndex, + DataIndexEntry, + FileStorage, + ObjectStorage, + StorageInfo, + StorageMapping, +) def test_map_get(tmp_path, odb): @@ -47,3 +56,241 @@ def test_map_get(tmp_path, odb): assert sinfo.data == data assert sinfo.cache == cache assert sinfo.remote == remote + + +class TestObjectStorageBulkExists: + def test_empty_entries(self, odb): + storage = ObjectStorage(key=(), odb=odb) + result = storage.bulk_exists([]) + assert result == {} + + def test_entries_without_hash(self, odb): + storage = ObjectStorage(key=(), odb=odb) + entry = DataIndexEntry(key=("foo",), meta=Meta()) + result = storage.bulk_exists([entry]) + assert result == {entry: False} + + def test_entries_exist_in_odb(self, odb): + storage = ObjectStorage(key=(), odb=odb) + entry = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + result = storage.bulk_exists([entry]) + assert result == {entry: True} + + def test_entries_not_in_odb(self, make_odb): + empty_odb = make_odb() + storage = ObjectStorage(key=(), odb=empty_odb) + entry = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "nonexistent"), + ) + result = storage.bulk_exists([entry]) + assert result == {entry: False} + + def test_with_index_no_refresh(self, odb): + index = DataIndex() + key = odb._oid_parts("d3b07384d113edec49eaa6238ad5ff00") + index[key] = DataIndexEntry(key=key) + + storage = ObjectStorage(key=(), odb=odb, index=index) + entry_exists = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + entry_not_in_index = DataIndexEntry( + key=("bar",), + hash_info=HashInfo("md5", "c157a79031e1c40f85931829bc5fc552"), + ) + + result = storage.bulk_exists([entry_exists, entry_not_in_index], refresh=False) + assert result[entry_exists] is True + assert result[entry_not_in_index] is False + + def test_with_index_refresh_existing(self, odb): + index = DataIndex() + storage = ObjectStorage(key=(), 
odb=odb, index=index) + + entry_exists = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + + result = storage.bulk_exists([entry_exists], refresh=True) + assert result[entry_exists] is True + + key_exists = odb._oid_parts("d3b07384d113edec49eaa6238ad5ff00") + assert key_exists in index + + def test_mixed_entries(self, odb): + storage = ObjectStorage(key=(), odb=odb) + entry_with_hash = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + entry_without_hash = DataIndexEntry(key=("bar",), meta=Meta()) + + result = storage.bulk_exists([entry_with_hash, entry_without_hash]) + assert result[entry_with_hash] is True + assert result[entry_without_hash] is False + + def test_multiple_entries(self, odb): + storage = ObjectStorage(key=(), odb=odb) + entries = [ + DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ), + DataIndexEntry( + key=("bar",), + hash_info=HashInfo("md5", "c157a79031e1c40f85931829bc5fc552"), + ), + DataIndexEntry( + key=("baz",), + hash_info=HashInfo("md5", "258622b1688250cb619f3c9ccaefb7eb"), + ), + ] + + result = storage.bulk_exists(entries) + assert all(result[e] is True for e in entries) + + +class TestStorageMappingBulkExists: + def test_bulk_cache_exists_empty(self, odb): + smap = StorageMapping() + smap.add_cache(ObjectStorage(key=(), odb=odb)) + result = smap.bulk_cache_exists([]) + assert result == {} + + def test_bulk_remote_exists_empty(self, odb): + smap = StorageMapping() + smap.add_remote(ObjectStorage(key=(), odb=odb)) + result = smap.bulk_remote_exists([]) + assert result == {} + + def test_bulk_cache_exists_all_exist(self, make_odb): + cache_odb = make_odb() + cache_odb.add_bytes("d3b07384d113edec49eaa6238ad5ff00", b"foo\n") + cache_odb.add_bytes("c157a79031e1c40f85931829bc5fc552", b"bar\n") + + smap = StorageMapping() + smap.add_cache(ObjectStorage(key=(), odb=cache_odb)) + + entries = [ + DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ), + DataIndexEntry( + key=("bar",), + hash_info=HashInfo("md5", "c157a79031e1c40f85931829bc5fc552"), + ), + ] + + result = smap.bulk_cache_exists(entries) + assert all(result[e] is True for e in entries) + + def test_bulk_remote_exists_all_exist(self, odb): + smap = StorageMapping() + smap.add_remote(ObjectStorage(key=(), odb=odb)) + + entries = [ + DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ), + DataIndexEntry( + key=("bar",), + hash_info=HashInfo("md5", "c157a79031e1c40f85931829bc5fc552"), + ), + ] + + result = smap.bulk_remote_exists(entries) + assert all(result[e] is True for e in entries) + + def test_bulk_cache_exists_missing_storage(self, odb): + smap = StorageMapping() + smap.add_remote(ObjectStorage(key=(), odb=odb)) + + entry = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + + result = smap.bulk_cache_exists([entry]) + # no cache storage, should be skipped + assert entry not in result + + def test_bulk_remote_exists_missing_storage(self, odb): + smap = StorageMapping() + smap.add_cache(ObjectStorage(key=(), odb=odb)) + + entry = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"), + ) + + result = smap.bulk_remote_exists([entry]) + # no remote storage, should be skipped + assert entry not in result + + def test_bulk_exists_multiple_storages(self, make_odb): + cache1 
= make_odb() + cache1.add_bytes("hash1", b"data1") + cache2 = make_odb() + cache2.add_bytes("hash2", b"data2") + + smap = StorageMapping() + smap.add_cache(ObjectStorage(key=(), odb=cache1)) + smap.add_cache(ObjectStorage(key=("subdir",), odb=cache2)) + + entry1 = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "hash1"), + ) + entry2 = DataIndexEntry( + key=("subdir", "bar"), + hash_info=HashInfo("md5", "hash2"), + ) + + result = smap.bulk_cache_exists([entry1, entry2]) + assert result[entry1] is True + assert result[entry2] is True + + def test_bulk_exists_shared_odb(self, make_odb): + odb = make_odb() + odb.add_bytes("hash1", b"data1") + odb.add_bytes("hash2", b"data2") + + smap = StorageMapping() + # two logical storages, one physical ODB + smap.add_cache(ObjectStorage(key=(), odb=odb)) + smap.add_cache(ObjectStorage(key=("subdir",), odb=odb)) + + entry1 = DataIndexEntry( + key=("foo",), + hash_info=HashInfo("md5", "hash1"), + ) + entry2 = DataIndexEntry( + key=("subdir", "bar"), + hash_info=HashInfo("md5", "hash2"), + ) + + result = smap.bulk_cache_exists([entry1, entry2]) + assert result[entry1] is True + assert result[entry2] is True + + def test_bulk_cache_exists_with_file_storage(self, tmp_path): + (tmp_path / "foo.txt").write_text("hello") + fs = LocalFileSystem() + + smap = StorageMapping() + smap.add_cache(FileStorage(key=(), fs=fs, path=str(tmp_path))) + + entry_exists = DataIndexEntry(key=("foo.txt",)) + entry_not_exists = DataIndexEntry(key=("bar.txt",)) + + result = smap.bulk_cache_exists([entry_exists, entry_not_exists]) + assert result[entry_exists] is True + assert result[entry_not_exists] is False
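
For context, here is a rough usage sketch (not part of the patch) of how the new bulk APIs could be driven from caller code. It relies only on names this diff introduces (StorageMapping.bulk_cache_exists, ObjectStorage.bulk_exists) plus pieces already used in the tests above; LocalHashFileDB and the temporary-directory setup are assumptions made to keep the example self-contained, and any HashFileDB-compatible ODB would work the same way.

# Usage sketch, not part of the diff. Assumption: LocalHashFileDB(fs, path)
# is a valid way to build an on-disk ODB; only bulk_cache_exists/bulk_exists
# come from this patch.
import tempfile

from dvc_objects.fs.local import LocalFileSystem

from dvc_data.hashfile.db.local import LocalHashFileDB
from dvc_data.hashfile.hash_info import HashInfo
from dvc_data.index import DataIndexEntry, ObjectStorage, StorageMapping

# Build a small local ODB containing one object (md5 of b"foo\n").
odb = LocalHashFileDB(LocalFileSystem(), tempfile.mkdtemp())
odb.add_bytes("d3b07384d113edec49eaa6238ad5ff00", b"foo\n")

smap = StorageMapping()
smap.add_cache(ObjectStorage(key=(), odb=odb))

entries = [
    DataIndexEntry(
        key=("foo",),
        hash_info=HashInfo("md5", "d3b07384d113edec49eaa6238ad5ff00"),
    ),
    DataIndexEntry(
        key=("bar",),
        # md5 of b"bar\n"; deliberately not added to the ODB above
        hash_info=HashInfo("md5", "c157a79031e1c40f85931829bc5fc552"),
    ),
]

# One call instead of per-entry exists() lookups: entries are grouped by
# storage/ODB internally, so a shared ODB is only consulted once per batch.
present = smap.bulk_cache_exists(entries)
missing = [e for e in entries if not present[e]]
print(f"missing from cache: {len(missing)} of {len(entries)}")

The same pattern applies to bulk_remote_exists, and keyword arguments such as refresh=True or jobs=N are forwarded through to ObjectStorage.bulk_exists, where an index-backed remote can re-list objects in batches instead of stat-ing them one by one.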