Skip to content

Commit

Permalink
dvc: use HashInfo (#4495)
Browse files Browse the repository at this point in the history
* dvc: use HashInfo

Related to #4144 , #3069 , #1676

* Update dvc/tree/s3.py

Co-authored-by: Saugat Pachhai <[email protected]>

Co-authored-by: Saugat Pachhai <[email protected]>
  • Loading branch information
efiop and skshetry authored Aug 30, 2020
1 parent 69c63a8 commit 1a35128
Show file tree
Hide file tree
Showing 21 changed files with 89 additions and 76 deletions.
14 changes: 6 additions & 8 deletions dvc/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ def changed(self, path_info, hash_info):
logger.debug("cache for '%s'('%s') has changed.", path_info, hash_)
return True

typ, actual = self.tree.get_hash(path_info)
assert typ == self.tree.PARAM_CHECKSUM

if hash_ != actual:
actual = self.tree.get_hash(path_info)
if hash_ != actual.value:
logger.debug(
"hash value '%s' for '%s' has changed (actual '%s').",
hash_,
Expand Down Expand Up @@ -319,7 +317,7 @@ def changed_cache_file(self, hash_):
)
return False

_, actual = self.tree.get_hash(cache_info)
actual = self.tree.get_hash(cache_info)

logger.debug(
"cache '%s' expected '%s' actual '%s'", cache_info, hash_, actual,
Expand All @@ -328,7 +326,7 @@ def changed_cache_file(self, hash_):
if not hash_ or not actual:
return True

if actual.split(".")[0] == hash_.split(".")[0]:
if actual.value.split(".")[0] == hash_.split(".")[0]:
# making cache file read-only so we don't need to check it
# next time
self.tree.protect(cache_info)
Expand Down Expand Up @@ -634,5 +632,5 @@ def merge(self, ancestor_info, our_info, their_info):
their = self.get_dir_cache(their_hash)

merged = self._merge_dirs(ancestor, our, their)
typ, merged_hash = self.tree.save_dir_info(merged)
return {typ: merged_hash}
hash_info = self.tree.save_dir_info(merged)
return {hash_info.name: hash_info.value}
8 changes: 4 additions & 4 deletions dvc/cache/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def hashes_exist(
def already_cached(self, path_info):
assert path_info.scheme in ["", "local"]

typ, current_md5 = self.tree.get_hash(path_info)
current = self.tree.get_hash(path_info)

assert typ == "md5"
assert current.name == "md5"

if not current_md5:
if not current:
return False

return not self.changed_cache(current_md5)
return not self.changed_cache(current.value)

def _verify_link(self, path_info, link_type):
if link_type == "hardlink" and self.tree.getsize(path_info) == 0:
Expand Down
2 changes: 1 addition & 1 deletion dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _get_checksum(self, locked=True):
if tree.isdir(path):
return self.repo.cache.local.tree.get_hash(
path, tree=tree
)[1]
).value
return tree.get_file_hash(path)

def workspace_status(self):
Expand Down
10 changes: 10 additions & 0 deletions dvc/hash_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class HashInfo:
name: str
value: str

def __bool__(self):
return bool(self.value)
2 changes: 1 addition & 1 deletion dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def checksum(self):
return self.info.get(self.tree.PARAM_CHECKSUM)

def get_checksum(self):
return self.tree.get_hash(self.path_info)[1]
return self.tree.get_hash(self.path_info).value

@property
def is_dir_checksum(self):
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _to_path(output):

def _to_checksum(output):
if on_working_tree:
return self.cache.local.tree.get_hash(output.path_info)[1]
return self.cache.local.tree.get_hash(output.path_info).value
return output.checksum

def _exists(output):
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from funcy import cached_property, wrap_prop

from dvc.hash_info import HashInfo
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
Expand Down Expand Up @@ -153,7 +154,7 @@ def remove(self, path_info):
).delete_blob()

def get_file_hash(self, path_info):
return self.PARAM_CHECKSUM, self.get_etag(path_info)
return HashInfo(self.PARAM_CHECKSUM, self.get_etag(path_info))

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand Down
38 changes: 19 additions & 19 deletions dvc/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DvcIgnoreInCollectedDirError,
RemoteCacheRequiredError,
)
from dvc.hash_info import HashInfo
from dvc.ignore import DvcIgnore
from dvc.path_info import PathInfo, URLInfo
from dvc.progress import Tqdm
Expand Down Expand Up @@ -242,7 +243,7 @@ def get_hash(self, path_info, **kwargs):
)

if not self.exists(path_info):
return self.PARAM_CHECKSUM, None
return None

# pylint: disable=assignment-from-none
hash_ = self.state.get(path_info)
Expand All @@ -260,17 +261,17 @@ def get_hash(self, path_info, **kwargs):
hash_ = None

if hash_:
return self.PARAM_CHECKSUM, hash_
return HashInfo(self.PARAM_CHECKSUM, hash_)

if self.isdir(path_info):
typ, hash_ = self.get_dir_hash(path_info, **kwargs)
hash_info = self.get_dir_hash(path_info, **kwargs)
else:
typ, hash_ = self.get_file_hash(path_info)
hash_info = self.get_file_hash(path_info)

if hash_ and self.exists(path_info):
self.state.save(path_info, hash_)
if hash_info and self.exists(path_info):
self.state.save(path_info, hash_info.value)

return typ, hash_
return hash_info

def get_file_hash(self, path_info):
raise NotImplementedError
Expand All @@ -294,8 +295,8 @@ def path_to_hash(self, path):
return "".join(parts)

def save_info(self, path_info, **kwargs):
typ, hash_ = self.get_hash(path_info, **kwargs)
return {typ: hash_}
hash_info = self.get_hash(path_info, **kwargs)
return {hash_info.name: hash_info.value}

def _calculate_hashes(self, file_infos):
file_infos = list(file_infos)
Expand All @@ -306,9 +307,7 @@ def _calculate_hashes(self, file_infos):
) as pbar:
worker = pbar.wrap_fn(self.get_file_hash)
with ThreadPoolExecutor(max_workers=self.hash_jobs) as executor:
hashes = (
value for typ, value in executor.map(worker, file_infos)
)
hashes = (hi.value for hi in executor.map(worker, file_infos))
return dict(zip(file_infos, hashes))

def _collect_dir(self, path_info, **kwargs):
Expand Down Expand Up @@ -346,17 +345,17 @@ def _collect_dir(self, path_info, **kwargs):
return sorted(result, key=itemgetter(self.PARAM_RELPATH))

def save_dir_info(self, dir_info):
typ, hash_, tmp_info = self._get_dir_info_hash(dir_info)
new_info = self.cache.tree.hash_to_path_info(hash_)
if self.cache.changed_cache_file(hash_):
hash_info, tmp_info = self._get_dir_info_hash(dir_info)
new_info = self.cache.tree.hash_to_path_info(hash_info.value)
if self.cache.changed_cache_file(hash_info.value):
self.cache.tree.makedirs(new_info.parent)
self.cache.tree.move(
tmp_info, new_info, mode=self.cache.CACHE_MODE
)

self.state.save(new_info, hash_)
self.state.save(new_info, hash_info.value)

return typ, hash_
return hash_info

def _get_dir_info_hash(self, dir_info):
# Sorting the list by path to ensure reproducibility
Expand All @@ -371,8 +370,9 @@ def _get_dir_info_hash(self, dir_info):
to_info = tree.path_info / tmp_fname("")
tree.upload(from_info, to_info, no_progress_bar=True)

typ, hash_ = tree.get_file_hash(to_info)
return typ, hash_ + self.CHECKSUM_DIR_SUFFIX, to_info
hash_info = tree.get_file_hash(to_info)
hash_info.value += self.CHECKSUM_DIR_SUFFIX
return hash_info, to_info

def upload(self, from_info, to_info, name=None, no_progress_bar=False):
if not hasattr(self, "_upload"):
Expand Down
7 changes: 4 additions & 3 deletions dvc/tree/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from dvc.exceptions import OutputNotFoundError
from dvc.hash_info import HashInfo
from dvc.path_info import PathInfo

from ._metadata import Metadata
Expand Down Expand Up @@ -245,7 +246,7 @@ def get_dir_hash(self, path_info, **kwargs):
out = outs[0]
# other code expects us to fetch the dir at this point
self._fetch_dir(out, **kwargs)
return out.tree.PARAM_CHECKSUM, out.checksum
return HashInfo(out.tree.PARAM_CHECKSUM, out.checksum)
except OutputNotFoundError:
pass

Expand All @@ -257,11 +258,11 @@ def get_file_hash(self, path_info):
raise OutputNotFoundError
out = outs[0]
if out.is_dir_checksum:
return (
return HashInfo(
out.tree.PARAM_CHECKSUM,
self._get_granular_checksum(path_info, out),
)
return out.tree.PARAM_CHECKSUM, out.checksum
return HashInfo(out.tree.PARAM_CHECKSUM, out.checksum)

def metadata(self, path_info):
path_info = PathInfo(os.path.abspath(path_info))
Expand Down
5 changes: 3 additions & 2 deletions dvc/tree/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from funcy import cached_property, wrap_prop

from dvc.exceptions import DvcException
from dvc.hash_info import HashInfo
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
Expand Down Expand Up @@ -189,11 +190,11 @@ def get_file_hash(self, path_info):
path = path_info.path
blob = self.gs.bucket(bucket).get_blob(path)
if not blob:
return self.PARAM_CHECKSUM, None
return HashInfo(self.PARAM_CHECKSUM, None)

b64_md5 = blob.md5_hash
md5 = base64.b64decode(b64_md5)
return (
return HashInfo(
self.PARAM_CHECKSUM,
codecs.getencoder("hex")(md5)[0].decode("utf-8"),
)
Expand Down
5 changes: 4 additions & 1 deletion dvc/tree/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import closing, contextmanager
from urllib.parse import urlparse

from dvc.hash_info import HashInfo
from dvc.scheme import Schemes
from dvc.utils import fix_env, tmp_fname

Expand Down Expand Up @@ -175,7 +176,9 @@ def get_file_hash(self, path_info):
stdout = self.hadoop_fs(
f"checksum {path_info.url}", user=path_info.user
)
return self.PARAM_CHECKSUM, self._group(regex, stdout, "checksum")
return HashInfo(
self.PARAM_CHECKSUM, self._group(regex, stdout, "checksum")
)

def _upload(self, from_file, to_info, **_kwargs):
with self.hdfs(to_info) as hdfs:
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dvc.prompt as prompt
from dvc.exceptions import DvcException, HTTPError
from dvc.hash_info import HashInfo
from dvc.path_info import HTTPURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
Expand Down Expand Up @@ -151,7 +152,7 @@ def get_file_hash(self, path_info):
"Content-MD5 header for '{url}'".format(url=url)
)

return self.PARAM_CHECKSUM, etag
return HashInfo(self.PARAM_CHECKSUM, etag)

def _download(self, from_info, to_file, name=None, no_progress_bar=False):
response = self.request("GET", from_info.url, stream=True)
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from shortuuid import uuid

from dvc.exceptions import DvcException
from dvc.hash_info import HashInfo
from dvc.path_info import PathInfo
from dvc.scheme import Schemes
from dvc.system import System
Expand Down Expand Up @@ -309,7 +310,7 @@ def is_protected(self, path_info):
return stat.S_IMODE(mode) == self.CACHE_MODE

def get_file_hash(self, path_info):
return self.PARAM_CHECKSUM, file_md5(path_info)[0]
return HashInfo(self.PARAM_CHECKSUM, file_md5(path_info)[0])

@staticmethod
def getsize(path_info):
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from dvc.dvcfile import is_valid_filename
from dvc.exceptions import OutputNotFoundError
from dvc.hash_info import HashInfo
from dvc.path_info import PathInfo
from dvc.utils import file_md5, is_exec
from dvc.utils.fs import copy_fobj_to_file, makedirs
Expand Down Expand Up @@ -332,7 +333,7 @@ def get_file_hash(self, path_info):
return dvc_tree.get_file_hash(path_info)
except OutputNotFoundError:
pass
return self.PARAM_CHECKSUM, file_md5(path_info, self)[0]
return HashInfo(self.PARAM_CHECKSUM, file_md5(path_info, self)[0])

def copytree(self, top, dest):
top = PathInfo(top)
Expand Down
6 changes: 2 additions & 4 deletions dvc/tree/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dvc.config import ConfigError
from dvc.exceptions import DvcException, ETagMismatchError
from dvc.hash_info import HashInfo
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
Expand Down Expand Up @@ -332,10 +333,7 @@ def _copy(cls, s3, from_info, to_info, extra_args):

def get_file_hash(self, path_info):
with self._get_obj(path_info) as obj:
return (
self.PARAM_CHECKSUM,
obj.e_tag.strip('"'),
)
return HashInfo(self.PARAM_CHECKSUM, obj.e_tag.strip('"'))

def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
with self._get_obj(to_info) as obj:
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from funcy import first, memoize, silent, wrap_with

import dvc.prompt as prompt
from dvc.hash_info import HashInfo
from dvc.scheme import Schemes

from ..base import BaseTree
Expand Down Expand Up @@ -238,7 +239,7 @@ def get_file_hash(self, path_info):
raise NotImplementedError

with self.ssh(path_info) as ssh:
return self.PARAM_CHECKSUM, ssh.md5(path_info.path)
return HashInfo(self.PARAM_CHECKSUM, ssh.md5(path_info.path))

def getsize(self, path_info):
with self.ssh(path_info) as ssh:
Expand Down
3 changes: 2 additions & 1 deletion dvc/tree/webdav.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dvc.config import ConfigError
from dvc.exceptions import DvcException
from dvc.hash_info import HashInfo
from dvc.path_info import HTTPURLInfo, WebDAVURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
Expand Down Expand Up @@ -142,7 +143,7 @@ def get_file_hash(self, path_info):
"Content-MD5 header for '{url}'".format(url=path_info.url)
)

return self.PARAM_CHECKSUM, etag
return HashInfo(self.PARAM_CHECKSUM, etag)

# Checks whether path points to directory
def isdir(self, path_info):
Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, local_cloud):
# into dvc.cache, not fetched or streamed from a remote
tree = RepoTree(erepo_dir.dvc, stream=True)
expected = [
tree.get_file_hash(PathInfo(erepo_dir / path))[1]
tree.get_file_hash(PathInfo(erepo_dir / path)).value
for path in ("dir/bar", "dir/subdir/foo")
]

Expand Down
Loading

0 comments on commit 1a35128

Please sign in to comment.