Skip to content

Commit 4f3fb15

Browse files
authored
dvcfs: optimize get() by reducing index.info calls() (#10540)
1 parent 016f285 commit 4f3fb15

File tree

5 files changed

+1007
-26
lines changed

5 files changed

+1007
-26
lines changed

dvc/dependency/repo.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -98,12 +98,13 @@ def download(self, to: "Output", jobs: Optional[int] = None):
9898

9999
files = super().download(to=to, jobs=jobs)
100100
if not isinstance(to.fs, LocalFileSystem):
101-
return files
101+
return
102102

103103
hashes: list[tuple[str, HashInfo, dict[str, Any]]] = []
104-
for src_path, dest_path in files:
104+
for src_path, dest_path, *rest in files:
105105
try:
106-
hash_info = self.fs.info(src_path)["dvc_info"]["entry"].hash_info
106+
info = rest[0] if rest else self.fs.info(src_path)
107+
hash_info = info["dvc_info"]["entry"].hash_info
107108
dest_info = to.fs.info(dest_path)
108109
except (KeyError, AttributeError):
109110
# If no hash info found, just keep going and output will be hashed later
@@ -112,7 +113,6 @@ def download(self, to: "Output", jobs: Optional[int] = None):
112113
hashes.append((dest_path, hash_info, dest_info))
113114
cache = to.cache if to.use_cache else to.local_cache
114115
cache.state.save_many(hashes, to.fs)
115-
return files
116116

117117
def update(self, rev: Optional[str] = None):
118118
if rev:

dvc/fs/__init__.py

+14-11
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import glob
2-
from typing import Optional
2+
from typing import Optional, Union
33
from urllib.parse import urlparse
44

55
from dvc.config import ConfigError as RepoConfigError
@@ -47,12 +47,24 @@
4747

4848
def download(
4949
fs: "FileSystem", fs_path: str, to: str, jobs: Optional[int] = None
50-
) -> list[tuple[str, str]]:
50+
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
5151
from dvc.scm import lfs_prefetch
5252

5353
from .callbacks import TqdmCallback
5454

5555
with TqdmCallback(desc=f"Downloading {fs.name(fs_path)}", unit="files") as cb:
56+
if isinstance(fs, DVCFileSystem):
57+
lfs_prefetch(
58+
fs,
59+
[
60+
f"{fs.normpath(glob.escape(fs_path))}/**"
61+
if fs.isdir(fs_path)
62+
else glob.escape(fs_path)
63+
],
64+
)
65+
if not glob.has_magic(fs_path):
66+
return fs._get(fs_path, to, batch_size=jobs, callback=cb)
67+
5668
# NOTE: We use dvc-objects generic.copy over fs.get since it makes file
5769
# download atomic and avoids fsspec glob/regex path expansion.
5870
if fs.isdir(fs_path):
@@ -69,15 +81,6 @@ def download(
6981
from_infos = [fs_path]
7082
to_infos = [to]
7183

72-
if isinstance(fs, DVCFileSystem):
73-
lfs_prefetch(
74-
fs,
75-
[
76-
f"{fs.normpath(glob.escape(fs_path))}/**"
77-
if fs.isdir(fs_path)
78-
else glob.escape(fs_path)
79-
],
80-
)
8184
cb.set_size(len(from_infos))
8285
jobs = jobs or fs.jobs
8386
generic.copy(fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs)

dvc/fs/dvc.py

+146-2
Original file line number | Diff line number | Diff line change
@@ -6,20 +6,24 @@
66
import threading
77
from collections import deque
88
from contextlib import ExitStack, suppress
9+
from glob import has_magic
910
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1011

11-
from fsspec.spec import AbstractFileSystem
12+
from fsspec.spec import DEFAULT_CALLBACK, AbstractFileSystem
1213
from funcy import wrap_with
1314

1415
from dvc.log import logger
15-
from dvc_objects.fs.base import FileSystem
16+
from dvc.utils.threadpool import ThreadPoolExecutor
17+
from dvc_objects.fs.base import AnyFSPath, FileSystem
1618

1719
from .data import DataFileSystem
1820

1921
if TYPE_CHECKING:
2022
from dvc.repo import Repo
2123
from dvc.types import DictStrAny, StrPath
2224

25+
from .callbacks import Callback
26+
2327
logger = logger.getChild(__name__)
2428

2529
RepoFactory = Union[Callable[..., "Repo"], type["Repo"]]
@@ -474,9 +478,110 @@ def _info( # noqa: C901
474478
info["name"] = path
475479
return info
476480

481+
def get(
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
):
    """Copy ``rpath`` to ``lpath`` by delegating to ``self._get``.

    Every argument is forwarded unchanged. Whatever ``_get`` returns is
    discarded, so this method always returns ``None``.
    """
    # Collect the named options once so the delegation reads as a single call.
    forwarded = {
        "recursive": recursive,
        "callback": callback,
        "maxdepth": maxdepth,
        "batch_size": batch_size,
        **kwargs,
    }
    self._get(rpath, lpath, **forwarded)
500+
501+
def _get(  # noqa: C901
    self,
    rpath,
    lpath,
    recursive=False,
    callback=DEFAULT_CALLBACK,
    maxdepth=None,
    batch_size=None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Download ``rpath`` to the local path ``lpath``.

    The call is handed to the parent class's ``get`` (and an empty list is
    returned) when either path is a list, ``rpath`` contains glob magic,
    ``rpath`` does not exist, or ``recursive`` is false.

    Otherwise a single file is fetched with ``get_file``; a directory is
    walked once with ``detail=True`` and the ``info`` dict produced by the
    walk is reused for each file transfer.

    Returns:
        ``[]`` on the fallback path, ``[(rpath, lpath)]`` for a single file,
        or one ``(src, dest, info)`` tuple per file for a directory. The
        directory results come from ``imap_unordered``, so their order
        should not be relied upon.
    """
    if (
        isinstance(rpath, list)
        or isinstance(lpath, list)
        or has_magic(rpath)
        or not self.exists(rpath)
        or not recursive
    ):
        super().get(
            rpath,
            lpath,
            recursive=recursive,
            callback=callback,
            maxdepth=maxdepth,
            **kwargs,
        )
        return []

    # Copying *into* an existing directory (or a path spelled with a trailing
    # separator) keeps the source's base name.
    if os.path.isdir(lpath) or lpath.endswith(os.path.sep):
        lpath = self.join(lpath, os.path.basename(rpath))

    if self.isfile(rpath):
        with callback.branched(rpath, lpath) as child:
            self.get_file(rpath, lpath, callback=child, **kwargs)
            return [(rpath, lpath)]

    _files = []
    _dirs: list[str] = []
    for root, dirs, files in self.walk(rpath, maxdepth=maxdepth, detail=True):
        if files:
            # Grow the progress total as files are discovered.
            callback.set_size((callback.size or 0) + len(files))

        parts = self.relparts(root, rpath)
        if parts in ((os.curdir,), ("",)):
            # ``root`` is ``rpath`` itself: no intermediate components.
            parts = ()
        dest_root = os.path.join(lpath, *parts)
        if not maxdepth or len(parts) < maxdepth - 1:
            _dirs.extend(f"{dest_root}{os.path.sep}{d}" for d in dirs)

        key = self._get_key_from_relative(root)
        _, dvc_fs, _ = self._get_subrepo_info(key)

        for name, info in files.items():
            src_path = f"{root}{self.sep}{name}"
            dest_path = f"{dest_root}{os.path.sep}{name}"
            _files.append((dvc_fs, src_path, dest_path, info))

    os.makedirs(lpath, exist_ok=True)
    for d in _dirs:
        # Tolerate directories that already exist, exactly as is done for
        # ``lpath`` above; a plain ``os.mkdir`` raises ``FileExistsError``
        # when downloading into a tree that is already (partially) present.
        os.makedirs(d, exist_ok=True)

    def _get_file(arg):
        """Transfer one ``(dvc_fs, src, dest, info)`` work item."""
        dvc_fs, src, dest, info = arg
        dvc_info = info.get("dvc_info")
        if dvc_info and dvc_fs:
            # Pass the already-known info along with the request instead of
            # having the DVC filesystem look it up again.
            dvc_path = dvc_info["name"]
            dvc_fs.get_file(
                dvc_path, dest, callback=callback, info=dvc_info, **kwargs
            )
        else:
            self.get_file(src, dest, callback=callback, **kwargs)
        return src, dest, info

    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        return list(executor.imap_unordered(_get_file, _files))
575+
477576
def get_file(self, rpath, lpath, **kwargs):
478577
key = self._get_key_from_relative(rpath)
479578
fs_path = self._from_key(key)
579+
580+
dirpath = os.path.dirname(lpath)
581+
if dirpath:
582+
# makedirs raises error if the string is empty
583+
os.makedirs(dirpath, exist_ok=True)
584+
480585
try:
481586
return self.repo.fs.get_file(fs_path, lpath, **kwargs)
482587
except FileNotFoundError:
@@ -553,6 +658,45 @@ def immutable(self):
553658
def getcwd(self):
554659
return self.fs.getcwd()
555660

661+
def _get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> list[Union[tuple[str, str], tuple[str, str, dict]]]:
    """Delegate to ``self.fs._get`` and return its result.

    The ``recursive`` argument passed by the caller is not used: the value
    forwarded is recomputed from the argument types, mirroring
    ``FileSystem.get``, which is non-recursive when both paths are lists
    and recursive otherwise.
    """
    both_are_lists = isinstance(from_info, list) and isinstance(to_info, list)
    return self.fs._get(
        from_info,
        to_info,
        callback=callback,
        recursive=not both_are_lists,
        batch_size=batch_size,
        **kwargs,
    )
681+
682+
def get(
    self,
    from_info: Union[AnyFSPath, list[AnyFSPath]],
    to_info: Union[AnyFSPath, list[AnyFSPath]],
    callback: "Callback" = DEFAULT_CALLBACK,
    recursive: bool = False,
    batch_size: Optional[int] = None,
    **kwargs,
) -> None:
    """Download ``from_info`` to ``to_info`` via ``self._get``.

    All arguments are forwarded as given; the list returned by ``_get`` is
    dropped so that this method returns ``None``.
    """
    options = {
        "callback": callback,
        "batch_size": batch_size,
        "recursive": recursive,
        **kwargs,
    }
    self._get(from_info, to_info, **options)
699+
556700
@property
557701
def fsid(self) -> str:
558702
return self.fs.fsid

tests/func/test_import.py

+38-9
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from dvc.testing.tmp_dir import make_subrepo
1414
from dvc.utils.fs import remove
1515
from dvc_data.hashfile import hash
16-
from dvc_data.index.index import DataIndexDirError
16+
from dvc_data.index.index import DataIndex, DataIndexDirError
1717

1818

1919
def test_import(tmp_dir, scm, dvc, erepo_dir):
@@ -725,12 +725,41 @@ def test_import_invalid_configs(tmp_dir, scm, dvc, erepo_dir):
725725
)
726726

727727

728-
def test_import_no_hash(tmp_dir, scm, dvc, erepo_dir, mocker):
728+
@pytest.mark.parametrize(
729+
"files,expected_info_calls",
730+
[
731+
({"foo": "foo"}, {("foo",)}),
732+
(
733+
{
734+
"dir": {
735+
"bar": "bar",
736+
"subdir": {"lorem": "ipsum", "nested": {"lorem": "lorem"}},
737+
}
738+
},
739+
# info calls should be made for only directories
740+
{("dir",), ("dir", "subdir"), ("dir", "subdir", "nested")},
741+
),
742+
],
743+
)
744+
def test_import_no_hash(
745+
tmp_dir, scm, dvc, erepo_dir, mocker, files, expected_info_calls
746+
):
729747
with erepo_dir.chdir():
730-
erepo_dir.dvc_gen("foo", "foo content", commit="create foo")
731-
732-
spy = mocker.spy(hash, "file_md5")
733-
stage = dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported")
734-
assert spy.call_count == 1
735-
for call in spy.call_args_list:
736-
assert stage.outs[0].fs_path != call.args[0]
748+
erepo_dir.dvc_gen(files, commit="create foo")
749+
750+
file_md5_spy = mocker.spy(hash, "file_md5")
751+
index_info_spy = mocker.spy(DataIndex, "info")
752+
name = next(iter(files))
753+
754+
dvc.imp(os.fspath(erepo_dir), name, "out")
755+
756+
local_hashes = [
757+
call.args[0]
758+
for call in file_md5_spy.call_args_list
759+
if call.args[1].protocol == "local"
760+
]
761+
# no files should be hashed, should use existing metadata
762+
assert not local_hashes
763+
assert {
764+
call.args[1] for call in index_info_spy.call_args_list
765+
} == expected_info_calls

0 commit comments

Comments
 (0)