Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Make Gluon download function to be atomic #12572

Merged
merged 38 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b7fc61d
use rename trick to achieve atomic write but didn't support python2 a…
stu1130 Sep 14, 2018
657a921
add test for multiprocess download
stu1130 Sep 14, 2018
353221a
implement atomic_replace referred by https://github.com/untitaker/pyt…
stu1130 Sep 15, 2018
32ff8cd
change the number of testing process to 10
stu1130 Sep 15, 2018
7ded536
add docstring and disable linter
stu1130 Sep 15, 2018
b4c6b1f
half way to address some issue reviewer have
stu1130 Sep 18, 2018
901ed34
use warning instead of raise UserWarn
stu1130 Sep 18, 2018
a9d40c4
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 18, 2018
91f9ca9
check for sha1
stu1130 Sep 19, 2018
9ab6341
Trigger CI
stu1130 Sep 19, 2018
1985726
fix the logic of checking hash
stu1130 Sep 19, 2018
429684a
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 19, 2018
df2a9ef
refine the error message
stu1130 Sep 20, 2018
602b7d0
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 20, 2018
3bed37c
add more comments and expose the error message to the user
stu1130 Sep 21, 2018
f0b8c5d
delete trailing whitespace
stu1130 Sep 21, 2018
47f65c2
rename _path_to_encode to _str_to_unicode
stu1130 Sep 22, 2018
85f9408
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 27, 2018
5a01234
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 28, 2018
8f58746
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Sep 28, 2018
b4235f0
fix the error message bug and add remove when the movefile fail on wi…
stu1130 Oct 2, 2018
f0776b9
add remove temp file for non-windows os
stu1130 Oct 2, 2018
c7756df
handle the OSError caused by os.remove
stu1130 Oct 2, 2018
971ecf9
Trigger CI
stu1130 Oct 2, 2018
1531594
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 2, 2018
4dec67c
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 2, 2018
3e799de
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 8, 2018
af5b8ec
use finally to raise failure of atomic replace
stu1130 Oct 8, 2018
25e21b6
Merge commit '5314cf4742767319ce356bd5154c6885380e0d5c' into fix_race…
stu1130 Oct 8, 2018
f690c69
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 8, 2018
0212e5a
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 8, 2018
6af8945
add missing try except block for os.remove
stu1130 Oct 9, 2018
a091a97
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 9, 2018
6765862
add retries value to error message
stu1130 Oct 10, 2018
f1359be
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 11, 2018
a7b7b12
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 11, 2018
e360853
Merge branch 'master' of https://github.com/apache/incubator-mxnet
stu1130 Oct 11, 2018
3662d38
Merge branch 'master' into fix_race_condion_download
stu1130 Oct 11, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
'check_sha1', 'download']

import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
Expand Down Expand Up @@ -195,6 +197,39 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash


if sys.platform != 'win32':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use startswith method as recommended here https://docs.python.org/2/library/sys.html#sys.platform and how will cygwin platform be handled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising the question. cygwin meets POSIX requirement and rename is guaranteed to be atomic in POSIX

# refer to https://github.com/untitaker/python-atomicwrites
def _replace_atomic(src, dst):
"""Implement atomic os.replace with linux and OSX. Internal Usa only"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Internal use only

os.rename(src, dst)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you put this in a try catch block to handle OSError

else:
from ctypes import windll, WinError

_MOVEFILE_REPLACE_EXISTING = 0x1
_MOVEFILE_WRITE_THROUGH = 0x8
_windows_default_flags = _MOVEFILE_WRITE_THROUGH

text_type = unicode if sys.version_info[0] == 2 else str # noqa

def _path_to_unicode(x):
"""Handle text decoding. Internal use only"""
if not isinstance(x, text_type):
return x.decode(sys.getfilesystemencoding())
return x

def _handle_errors(rv):
"""Handle WinError. Internal use only"""
if not rv:
raise WinError()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems not informative. what could be the causes? what should users do when this happens?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! will add more comments in docstring


def _replace_atomic(src, dst):
"""Implement atomic os.replace with windows. Internal use only"""
_handle_errors(windll.kernel32.MoveFileExW(
_path_to_unicode(src), _path_to_unicode(dst),
_windows_default_flags | _MOVEFILE_REPLACE_EXISTING
))


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL

Expand Down Expand Up @@ -231,7 +266,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
assert retries >= 0, 'Number of retries should be at least 0'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include the value of retries in the error message.


if not verify_ssl:
warnings.warn(
Expand All @@ -242,23 +277,34 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries+1 > 0:
while retries + 1 > 0:
# Disable pyling too broad Exception
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
# create uuid for temporary files
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: repeating comment, remove line 291

for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
# if the target file exists(created by other processes),
# delete the temporary file
if os.path.exists(fname):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check for file hash too, as the existing file might be an aborted incomplete download.

os.remove('{}.{}'.format(fname, random_uuid))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here if fname local path exists, then the newly downloaded file is silently deleted. We should probably raise a user warning saying that a local file by the same name exists and the downloaded file has not been saved.

It might also be better to first check for the presence of local file and then make the HTTP call in line 285. This way we will prevent making unnecessary calls, instead of actually downloading the file and then deleting it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add UserWarning and I checked the target file existence before starting downloading at line 280

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if the check exists there in line 280. why are we repeating it here. What purpose does this check serve?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under multiprocessing environment, two processes might pass the check and one of them finish the download, rename. while the other one is still downloading or right before rename. In such case, I choose to remove instead of rename because there is a rare case where rename might be not atomic in windows based on some research on MSDN. it uses remove to mitigate the risk

else:
# atmoic operation in the same file system
_replace_atomic('{}.{}'.format(fname, random_uuid), fname)
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'\
' The repo may be outdated or download may be incomplete. '\
'If the "repo_url" is overridden, consider switching to '\
'the default repo.'.format(fname))
raise UserWarning(
'File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
Expand Down
46 changes: 40 additions & 6 deletions tests/python/unittest/test_gluon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import os
import tempfile
import warnings
import glob
import shutil
import multiprocessing as mp

try:
from unittest import mock
Expand Down Expand Up @@ -46,22 +49,53 @@ def test_download_retries():

@mock.patch(
'requests.get',
mock.Mock(side_effect=
lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
def _download_successful(tmp):
""" internal use for testing download successfully """
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmp)


def test_download_successful():
""" test download with one process """
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100
_download_successful(tmpfile)
assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
pattern = os.path.join(tmp, 'README.md*')
# check only one file we want left
assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
# delete temp dir
shutil.rmtree(tmp)


def test_multiprocessing_download_successful():
""" test download with multiprocessing """
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
process_list = []
# test it with 10 processes
for i in range(10):
process_list.append(mp.Process(
target=_download_successful, args=(tmpfile,)))
process_list[i].start()
for i in range(10):
process_list[i].join()
assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
# check only one file we want left
pattern = os.path.join(tmp, 'README.md*')
assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
# delete temp dir
shutil.rmtree(tmp)


@mock.patch(
'requests.get',
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
""" test download verify_ssl parameter """
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)
Expand Down