tree: make get_hash return type and hash pair (#4397)
Currently we implicitly assume that whatever is returned by `get_file_hash`
is of type self.PARAM_CHECKSUM, which is not actually true. E.g. for
http it might return either an `etag` or an `md5`, but we don't distinguish
between the two and label both as `etag`. This is becoming more relevant
for dir hashes, which can be computed in a few different ways (e.g. an
in-memory md5, or uploading the dir file to a remote and taking its etag).

Prerequisite for #4144 and #3069
efiop authored Aug 15, 2020
1 parent b773ba7 commit 023dec4
Showing 17 changed files with 51 additions and 34 deletions.
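
To make the new contract concrete before the per-file diffs, here is a minimal sketch of the before/after calling convention (a toy class with a made-up hash value, not DVC's actual code):

    class LocalTree:
        PARAM_CHECKSUM = "md5"

        def get_file_hash(self, path_info):
            # Before this commit trees returned a bare hash string and callers
            # assumed its type; now they return a (type, hash) pair.
            return self.PARAM_CHECKSUM, "8c7dd922ad47494fc02c388e12c00eac"

        def save_info(self, path_info):
            # save_info can key the info dict by the reported type instead of
            # hard-coding PARAM_CHECKSUM.
            typ, hash_ = self.get_file_hash(path_info)
            return {typ: hash_}

    tree = LocalTree()
    typ, hash_ = tree.get_file_hash("foo")
    assert typ == tree.PARAM_CHECKSUM
    print(tree.save_info("foo"))  # {'md5': '8c7dd922ad47494fc02c388e12c00eac'}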
8 changes: 5 additions & 3 deletions dvc/cache/base.py
@@ -127,7 +127,9 @@ def changed(self, path_info, hash_info):
             logger.debug("cache for '%s'('%s') has changed.", path_info, hash_)
             return True

-        actual = self.tree.get_hash(path_info)
+        typ, actual = self.tree.get_hash(path_info)
+        assert typ == self.tree.PARAM_CHECKSUM
+
         if hash_ != actual:
             logger.debug(
                 "hash value '%s' for '%s' has changed (actual '%s').",
@@ -312,7 +314,7 @@ def changed_cache_file(self, hash_):
             )
             return False

-        actual = self.tree.get_hash(cache_info)
+        _, actual = self.tree.get_hash(cache_info)

         logger.debug(
             "cache '%s' expected '%s' actual '%s'", cache_info, hash_, actual,
@@ -358,7 +360,7 @@ def changed_cache(self, hash_, path_info=None, filter_info=None):
         return self.changed_cache_file(hash_)

     def already_cached(self, path_info):
-        current = self.tree.get_hash(path_info)
+        _, current = self.tree.get_hash(path_info)

         if not current:
             return False
4 changes: 3 additions & 1 deletion dvc/cache/local.py
@@ -91,7 +91,9 @@ def hashes_exist(
     def already_cached(self, path_info):
         assert path_info.scheme in ["", "local"]

-        current_md5 = self.tree.get_hash(path_info)
+        typ, current_md5 = self.tree.get_hash(path_info)

+        assert typ == "md5"
+
         if not current_md5:
             return False
4 changes: 3 additions & 1 deletion dvc/dependency/repo.py
@@ -64,7 +64,9 @@ def _get_checksum(self, locked=True):

                # We are polluting our repo cache with some dir listing here
                if tree.isdir(path):
-                    return self.repo.cache.local.tree.get_hash(path, tree=tree)
+                    return self.repo.cache.local.tree.get_hash(
+                        path, tree=tree
+                    )[1]
                return tree.get_file_hash(path)

     def workspace_status(self):
2 changes: 1 addition & 1 deletion dvc/output/base.py
@@ -184,7 +184,7 @@ def checksum(self, checksum):
         self.info[self.tree.PARAM_CHECKSUM] = checksum

     def get_checksum(self):
-        return self.tree.get_hash(self.path_info)
+        return self.tree.get_hash(self.path_info)[1]

     @property
     def is_dir_checksum(self):
2 changes: 1 addition & 1 deletion dvc/repo/diff.py
@@ -37,7 +37,7 @@ def _to_path(output):

     def _to_checksum(output):
         if on_working_tree:
-            return self.cache.local.tree.get_hash(output.path_info)
+            return self.cache.local.tree.get_hash(output.path_info)[1]
         return output.checksum

     def _exists(output):
9 changes: 6 additions & 3 deletions dvc/repo/tree.py
@@ -238,8 +238,11 @@ def get_file_hash(self, path_info):
             raise OutputNotFoundError
         out = outs[0]
         if out.is_dir_checksum:
-            return self._get_granular_checksum(path_info, out)
-        return out.checksum
+            return (
+                out.tree.PARAM_CHECKSUM,
+                self._get_granular_checksum(path_info, out),
+            )
+        return out.tree.PARAM_CHECKSUM, out.checksum


 class RepoTree(BaseTree):  # pylint:disable=abstract-method
@@ -504,7 +507,7 @@ def get_file_hash(self, path_info):
                 return dvc_tree.get_file_hash(path_info)
             except OutputNotFoundError:
                 pass
-        return file_md5(path_info, self)[0]
+        return self.PARAM_CHECKSUM, file_md5(path_info, self)[0]

     def copytree(self, top, dest):
         top = PathInfo(top)
2 changes: 1 addition & 1 deletion dvc/tree/azure.py
@@ -153,7 +153,7 @@ def remove(self, path_info):
         ).delete_blob()

     def get_file_hash(self, path_info):
-        return self.get_etag(path_info)
+        return self.PARAM_CHECKSUM, self.get_etag(path_info)

     def _upload(
         self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
28 changes: 15 additions & 13 deletions dvc/tree/base.py
@@ -242,7 +242,7 @@ def get_hash(self, path_info, **kwargs):
         )

         if not self.exists(path_info):
-            return None
+            return self.PARAM_CHECKSUM, None

         # pylint: disable=assignment-from-none
         hash_ = self.state.get(path_info)
@@ -260,17 +260,17 @@ def get_hash(self, path_info, **kwargs):
             hash_ = None

         if hash_:
-            return hash_
+            return self.PARAM_CHECKSUM, hash_

         if self.isdir(path_info):
-            hash_ = self.get_dir_hash(path_info, **kwargs)
+            typ, hash_ = self.get_dir_hash(path_info, **kwargs)
         else:
-            hash_ = self.get_file_hash(path_info)
+            typ, hash_ = self.get_file_hash(path_info)

         if hash_ and self.exists(path_info):
             self.state.save(path_info, hash_)

-        return hash_
+        return typ, hash_

     def get_file_hash(self, path_info):
         raise NotImplementedError
@@ -294,7 +294,8 @@ def path_to_hash(self, path):
         return "".join(parts)

     def save_info(self, path_info, **kwargs):
-        return {self.PARAM_CHECKSUM: self.get_hash(path_info, **kwargs)}
+        typ, hash_ = self.get_hash(path_info, **kwargs)
+        return {typ: hash_}

     def _calculate_hashes(self, file_infos):
         file_infos = list(file_infos)
@@ -305,9 +306,10 @@ def _calculate_hashes(self, file_infos):
         ) as pbar:
             worker = pbar.wrap_fn(self.get_file_hash)
             with ThreadPoolExecutor(max_workers=self.hash_jobs) as executor:
-                tasks = executor.map(worker, file_infos)
-                hashes = dict(zip(file_infos, tasks))
-                return hashes
+                hashes = (
+                    value for typ, value in executor.map(worker, file_infos)
+                )
+                return dict(zip(file_infos, hashes))

     def _collect_dir(self, path_info, **kwargs):
         file_infos = set()
@@ -344,7 +346,7 @@ def _collect_dir(self, path_info, **kwargs):
         return sorted(result, key=itemgetter(self.PARAM_RELPATH))

     def _save_dir_info(self, dir_info):
-        hash_, tmp_info = self._get_dir_info_hash(dir_info)
+        typ, hash_, tmp_info = self._get_dir_info_hash(dir_info)
         new_info = self.cache.tree.hash_to_path_info(hash_)
         if self.cache.changed_cache_file(hash_):
             self.cache.tree.makedirs(new_info.parent)
@@ -354,7 +356,7 @@

         self.state.save(new_info, hash_)

-        return hash_
+        return typ, hash_

     def _get_dir_info_hash(self, dir_info):
         tmp = tempfile.NamedTemporaryFile(delete=False).name
@@ -366,8 +368,8 @@ def _get_dir_info_hash(self, dir_info):
         to_info = tree.path_info / tmp_fname("")
         tree.upload(from_info, to_info, no_progress_bar=True)

-        hash_ = tree.get_file_hash(to_info) + self.CHECKSUM_DIR_SUFFIX
-        return hash_, to_info
+        typ, hash_ = tree.get_file_hash(to_info)
+        return typ, hash_ + self.CHECKSUM_DIR_SUFFIX, to_info

     def upload(self, from_info, to_info, name=None, no_progress_bar=False):
         if not hasattr(self, "_upload"):
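
For context on the dir-hash path changed above: the suffix is now appended to the hash value only, while the type passes through unchanged. A toy sketch of that flow (standalone functions and a made-up path, not DVC's classes; in DVC, CHECKSUM_DIR_SUFFIX is ".dir"):

    CHECKSUM_DIR_SUFFIX = ".dir"

    def get_file_hash(path):
        # Each tree reports its hash type alongside the value.
        return "md5", "d41d8cd98f00b204e9800998ecf8427e"

    def get_dir_info_hash(path):
        typ, hash_ = get_file_hash(path)
        # Suffix the value to mark a dir hash; keep the type as-is.
        return typ, hash_ + CHECKSUM_DIR_SUFFIX, path

    typ, hash_, tmp_info = get_dir_info_hash("/tmp/dir-manifest")
    assert (typ, hash_) == ("md5", "d41d8cd98f00b204e9800998ecf8427e.dir")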
7 changes: 5 additions & 2 deletions dvc/tree/gs.py
@@ -190,11 +190,14 @@ def get_file_hash(self, path_info):
         path = path_info.path
         blob = self.gs.bucket(bucket).get_blob(path)
         if not blob:
-            return None
+            return self.PARAM_CHECKSUM, None

         b64_md5 = blob.md5_hash
         md5 = base64.b64decode(b64_md5)
-        return codecs.getencoder("hex")(md5)[0].decode("utf-8")
+        return (
+            self.PARAM_CHECKSUM,
+            codecs.getencoder("hex")(md5)[0].decode("utf-8"),
+        )

     def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         bucket = self.gs.bucket(to_info.bucket)
2 changes: 1 addition & 1 deletion dvc/tree/hdfs.py
@@ -167,7 +167,7 @@ def get_file_hash(self, path_info):
         stdout = self.hadoop_fs(
             f"checksum {path_info.url}", user=path_info.user
         )
-        return self._group(regex, stdout, "checksum")
+        return self.PARAM_CHECKSUM, self._group(regex, stdout, "checksum")

     def _upload(self, from_file, to_info, **_kwargs):
         with self.hdfs(to_info) as hdfs:
2 changes: 1 addition & 1 deletion dvc/tree/http.py
@@ -136,7 +136,7 @@ def get_file_hash(self, path_info):
                 "Content-MD5 header for '{url}'".format(url=url)
             )

-        return etag
+        return self.PARAM_CHECKSUM, etag

     def _download(self, from_info, to_file, name=None, no_progress_bar=False):
         response = self.request("GET", from_info.url, stream=True)
2 changes: 1 addition & 1 deletion dvc/tree/local.py
@@ -309,7 +309,7 @@ def is_protected(self, path_info):
         return stat.S_IMODE(mode) == self.CACHE_MODE

     def get_file_hash(self, path_info):
-        return file_md5(path_info)[0]
+        return self.PARAM_CHECKSUM, file_md5(path_info)[0]

     @staticmethod
     def getsize(path_info):
5 changes: 4 additions & 1 deletion dvc/tree/s3.py
@@ -318,7 +318,10 @@ def _copy(cls, s3, from_info, to_info, extra_args):
             raise ETagMismatchError(etag, cached_etag)

     def get_file_hash(self, path_info):
-        return self.get_etag(self.s3, path_info.bucket, path_info.path)
+        return (
+            self.PARAM_CHECKSUM,
+            self.get_etag(self.s3, path_info.bucket, path_info.path),
+        )

     def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         total = os.path.getsize(from_file)
2 changes: 1 addition & 1 deletion dvc/tree/ssh/__init__.py
@@ -238,7 +238,7 @@ def get_file_hash(self, path_info):
             raise NotImplementedError

         with self.ssh(path_info) as ssh:
-            return ssh.md5(path_info.path)
+            return self.PARAM_CHECKSUM, ssh.md5(path_info.path)

     def getsize(self, path_info):
         with self.ssh(path_info) as ssh:
2 changes: 1 addition & 1 deletion dvc/tree/webdav.py
@@ -142,7 +142,7 @@ def get_file_hash(self, path_info):
                 "Content-MD5 header for '{url}'".format(url=path_info.url)
             )

-        return etag
+        return self.PARAM_CHECKSUM, etag

     # Checks whether path points to directory
     def isdir(self, path_info):
2 changes: 1 addition & 1 deletion tests/func/test_tree.py
@@ -211,7 +211,7 @@ def test_repotree_cache_save(tmp_dir, dvc, scm, erepo_dir, local_cloud):
     # into dvc.cache, not fetched or streamed from a remote
     tree = RepoTree(erepo_dir.dvc, stream=True)
     expected = [
-        tree.get_file_hash(PathInfo(erepo_dir / path))
+        tree.get_file_hash(PathInfo(erepo_dir / path))[1]
         for path in ("dir/bar", "dir/subdir/foo")
     ]
2 changes: 1 addition & 1 deletion tests/unit/remote/test_azure.py
@@ -36,7 +36,7 @@ def test_get_file_hash(tmp_dir, azure):
     to_info = azure
     tree.upload(PathInfo("foo"), to_info)
     assert tree.exists(to_info)
-    hash_ = tree.get_file_hash(to_info)
+    _, hash_ = tree.get_file_hash(to_info)
     assert hash_
     assert isinstance(hash_, str)
     assert hash_.strip("'").strip('"') == hash_
