Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Make Gluon download function to be atomic #12572

Merged
merged 38 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b7fc61d
use rename trick to achieve atomic write but didn't support python2 a…
stu1130 Sep 14, 2018
657a921
add test for multiprocess download
stu1130 Sep 14, 2018
353221a
implement atomic_replace referred by https://github.com/untitaker/pyt…
stu1130 Sep 15, 2018
32ff8cd
change the number of testing process to 10
stu1130 Sep 15, 2018
7ded536
add docstring and disable linter
stu1130 Sep 15, 2018
b4c6b1f
half way to address some issue reviewer have
stu1130 Sep 18, 2018
901ed34
use warning instead of raise UserWarn
stu1130 Sep 18, 2018
a9d40c4
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 18, 2018
91f9ca9
check for sha1
stu1130 Sep 19, 2018
9ab6341
Trigger CI
stu1130 Sep 19, 2018
1985726
fix the logic of checking hash
stu1130 Sep 19, 2018
429684a
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 19, 2018
df2a9ef
refine the error message
stu1130 Sep 20, 2018
602b7d0
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 20, 2018
3bed37c
add more comments and expose the error message to the user
stu1130 Sep 21, 2018
f0b8c5d
delete trailing whitespace
stu1130 Sep 21, 2018
47f65c2
rename _path_to_encode to _str_to_unicode
stu1130 Sep 22, 2018
85f9408
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 27, 2018
5a01234
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 28, 2018
8f58746
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 28, 2018
b4235f0
fix the error message bug and add remove when the movefile fail on wi…
stu1130 Oct 2, 2018
f0776b9
add remove temp file for non-windows os
stu1130 Oct 2, 2018
c7756df
handle the OSError caused by os.remove
stu1130 Oct 2, 2018
971ecf9
Trigger CI
stu1130 Oct 2, 2018
1531594
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 2, 2018
4dec67c
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 2, 2018
3e799de
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 8, 2018
af5b8ec
use finally to raise failure of atomic replace
stu1130 Oct 8, 2018
25e21b6
Merge commit '5314cf4742767319ce356bd5154c6885380e0d5c' into fix_race…
stu1130 Oct 8, 2018
f690c69
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 8, 2018
0212e5a
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 8, 2018
6af8945
add missing try except block for os.remove
stu1130 Oct 9, 2018
a091a97
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 9, 2018
6765862
add retries value to error message
stu1130 Oct 10, 2018
f1359be
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 11, 2018
a7b7b12
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 11, 2018
e360853
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 11, 2018
3662d38
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 11, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 86 additions & 11 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
'check_sha1', 'download']

import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
Expand Down Expand Up @@ -195,6 +197,62 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash


if not sys.platform.startswith('win32'):
# refer to https://github.com/untitaker/python-atomicwrites
def _replace_atomic(src, dst):
    """Implement atomic os.replace with linux and OSX. Internal use only.

    Moves ``src`` onto ``dst`` with ``os.rename``. If the rename fails, the
    temporary file ``src`` is removed (best effort) and an ``OSError`` that
    includes the underlying reason is raised.

    Parameters
    ----------
    src : str
        Path of the downloaded temporary file.
    dst : str
        Path the temporary file should be moved to.

    Raises
    ------
    OSError
        If ``os.rename`` fails.
    """
    try:
        os.rename(src, dst)
    except OSError as err:
        # The move failed, so the temp file is of no use any more: delete it
        # on a best-effort basis and ignore a failure to do so.
        try:
            os.remove(src)
        except OSError:
            pass
        # Raise after the cleanup rather than from a ``finally`` clause, so an
        # unrelated exception raised during cleanup (e.g. KeyboardInterrupt)
        # is propagated instead of being replaced by this OSError.
        raise OSError(
            'Moving downloaded temp file - {}, to {} failed. '
            'Please retry the download. Reason: {}'.format(src, dst, err))
else:
import ctypes

# MoveFileExW flag: allow the move to replace an existing destination file.
_MOVEFILE_REPLACE_EXISTING = 0x1
# MoveFileExW flag: setting this value guarantees that a move performed as a
# copy and delete operation is flushed to disk before the function returns.
# The flush occurs at the end of the copy operation.
_MOVEFILE_WRITE_THROUGH = 0x8
# Base flags that are OR-ed with _MOVEFILE_REPLACE_EXISTING when MoveFileExW is called.
_windows_default_flags = _MOVEFILE_WRITE_THROUGH

# Text string type of the running interpreter: ``unicode`` on Python 2,
# ``str`` on Python 3.
text_type = unicode if sys.version_info[0] == 2 else str # noqa

def _str_to_unicode(x):
    """Return ``x`` as a text string. Internal use only.

    Values that are already of ``text_type`` are returned unchanged; anything
    else is decoded with the file system encoding.
    """
    if isinstance(x, text_type):
        return x
    fs_encoding = sys.getfilesystemencoding()
    return x.decode(fs_encoding)

def _handle_errors(rv, src):
    """Handle WinError. Internal use only.

    Parameters
    ----------
    rv : int
        Return value of the Win32 call; a falsy value is treated as failure.
    src : str
        Path of the temporary file, removed (best effort) on failure.

    Raises
    ------
    OSError
        If ``rv`` is falsy, carrying the message formatted from
        ``ctypes.GetLastError()``.
    """
    if rv:
        return
    # Capture the error text before the cleanup below runs.
    msg = ctypes.FormatError(ctypes.GetLastError())
    # If MoveFileExW fails (e.g. it fails to acquire the file lock), remove
    # the temp file; a failure to remove it is deliberately ignored.
    try:
        os.remove(src)
    except OSError:
        pass
    # Raise after the cleanup rather than from a ``finally`` clause, so an
    # unrelated exception raised during cleanup (e.g. KeyboardInterrupt) is
    # propagated instead of being replaced by this OSError.
    raise OSError(msg)

def _replace_atomic(src, dst):
    """Implement atomic os.replace with windows.

    Moves ``src`` onto ``dst`` through the Win32 ``MoveFileExW`` call, see
    https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
    The function fails when one of the process(copy, flush, delete) fails;
    the result of the call is handed to ``_handle_errors`` together with
    ``src``.
    Internal use only"""
    move_file = ctypes.windll.kernel32.MoveFileExW
    src_text = _str_to_unicode(src)
    dst_text = _str_to_unicode(dst)
    flags = _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
    succeeded = move_file(src_text, dst_text, flags)
    _handle_errors(succeeded, src)


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL

Expand Down Expand Up @@ -231,7 +289,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
assert retries >= 0, 'Number of retries should be at least 0'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include the value of retries in the error message.


if not verify_ssl:
warnings.warn(
Expand All @@ -242,31 +300,48 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries+1 > 0:
while retries + 1 > 0:
# Disable pylint's too-broad-exception warning
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
# If the target file already exists (created by another process)
# and, when a hash is given, matches that hash,
# delete the temporary file; otherwise move the temporary file into place.
if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
# atomic operation within the same file system
_replace_atomic('{}.{}'.format(fname, random_uuid), fname)
else:
try:
os.remove('{}.{}'.format(fname, random_uuid))
except OSError:
pass
finally:
warnings.warn(
'File {} exists in file system so the downloaded file is deleted'.format(fname))
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'\
' The repo may be outdated or download may be incomplete. '\
'If the "repo_url" is overridden, consider switching to '\
'the default repo.'.format(fname))
raise UserWarning(
'File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
print('download failed due to {}, retrying, {} attempt{} left'
.format(repr(e), retries, 's' if retries > 1 else ''))

return fname

Expand Down
46 changes: 40 additions & 6 deletions tests/python/unittest/test_gluon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import os
import tempfile
import warnings
import glob
import shutil
import multiprocessing as mp

try:
from unittest import mock
Expand Down Expand Up @@ -46,22 +49,53 @@ def test_download_retries():

@mock.patch(
'requests.get',
mock.Mock(side_effect=
lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
def _download_successful(tmp):
""" internal use for testing download successfully """
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmp)


def test_download_successful():
""" test download with one process """
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100
_download_successful(tmpfile)
assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
pattern = os.path.join(tmp, 'README.md*')
# check only one file we want left
assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
# delete temp dir
shutil.rmtree(tmp)


def test_multiprocessing_download_successful():
    """ test download with multiprocessing

    Starts 10 processes that all download to the same target path, then
    checks that the target file has content and that it is the only file
    matching ``README.md*`` left in the temp directory.
    """
    tmp = tempfile.mkdtemp()
    # try/finally so the temp dir is removed even when an assertion fails
    try:
        tmpfile = os.path.join(tmp, 'README.md')
        process_list = []
        # test it with 10 processes
        for _ in range(10):
            process = mp.Process(
                target=_download_successful, args=(tmpfile,))
            process.start()
            process_list.append(process)
        for process in process_list:
            process.join()
        assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
        # check only one file we want left
        pattern = os.path.join(tmp, 'README.md*')
        assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
    finally:
        # delete temp dir
        shutil.rmtree(tmp)


@mock.patch(
'requests.get',
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
""" test download verify_ssl parameter """
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)
Expand Down