Skip to content

Commit

Permalink
Make Gluon download function to be atomic (apache#12572)
Browse files Browse the repository at this point in the history
* use the rename trick to achieve an atomic write (does not yet support Python 2 or Windows)

* add test for multiprocess download

* implement atomic_replace referred by https://github.com/untitaker/python-atomicwrites

* change the number of testing process to 10

* add docstring and disable linter

* partially address the issues raised by reviewers

* use warning instead of raise UserWarn

* check for sha1

* Trigger CI

* fix the logic of checking hash

* refine the error message

* add more comments and expose the error message to the user

* delete trailing whitespace

* rename _path_to_encode to _str_to_unicode

* fix the error message bug and remove the temp file when the file move fails on Windows

* add remove temp file for non-windows os

* handle the OSError caused by os.remove

* Trigger CI

* use finally to raise failure of atomic replace

* add missing try except block for os.remove

* add retries value to error message
  • Loading branch information
stu1130 authored and piyushghai committed Oct 19, 2018
1 parent 9c2b614 commit 751f40e
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 17 deletions.
98 changes: 87 additions & 11 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
'check_sha1', 'download']

import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
Expand Down Expand Up @@ -195,6 +197,62 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash


if not sys.platform.startswith('win32'):
# refer to https://github.com/untitaker/python-atomicwrites
def _replace_atomic(src, dst):
"""Implement atomic os.replace with linux and OSX. Internal use only"""
try:
os.rename(src, dst)
except OSError:
try:
os.remove(src)
except OSError:
pass
finally:
raise OSError(
'Moving downloaded temp file - {}, to {} failed. \
Please retry the download.'.format(src, dst))
else:
import ctypes

_MOVEFILE_REPLACE_EXISTING = 0x1
# Setting this value guarantees that a move performed as a copy
# and delete operation is flushed to disk before the function returns.
# The flush occurs at the end of the copy operation.
_MOVEFILE_WRITE_THROUGH = 0x8
_windows_default_flags = _MOVEFILE_WRITE_THROUGH

text_type = unicode if sys.version_info[0] == 2 else str # noqa

def _str_to_unicode(x):
"""Handle text decoding. Internal use only"""
if not isinstance(x, text_type):
return x.decode(sys.getfilesystemencoding())
return x

def _handle_errors(rv, src):
"""Handle WinError. Internal use only"""
if not rv:
msg = ctypes.FormatError(ctypes.GetLastError())
# if the MoveFileExW fails(e.g. fail to acquire file lock), removes the tempfile
try:
os.remove(src)
except OSError:
pass
finally:
raise OSError(msg)

def _replace_atomic(src, dst):
"""Implement atomic os.replace with windows.
refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
The function fails when one of the process(copy, flush, delete) fails.
Internal use only"""
_handle_errors(ctypes.windll.kernel32.MoveFileExW(
_str_to_unicode(src), _str_to_unicode(dst),
_windows_default_flags | _MOVEFILE_REPLACE_EXISTING
), src)


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL
Expand Down Expand Up @@ -231,7 +289,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0"
assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
retries)

if not verify_ssl:
warnings.warn(
Expand All @@ -242,31 +301,48 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
while retries+1 > 0:
while retries + 1 > 0:
# Disable pylint's too-broad-exception warning (W0703)
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
# if the target file already exists (created by another process)
# and, when a hash is given, its content matches that hash,
# delete the temporary file instead of replacing the target
if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
# atomic operation within the same file system
_replace_atomic('{}.{}'.format(fname, random_uuid), fname)
else:
try:
os.remove('{}.{}'.format(fname, random_uuid))
except OSError:
pass
finally:
warnings.warn(
'File {} exists in file system so the downloaded file is deleted'.format(fname))
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning('File {} is downloaded but the content hash does not match.'\
' The repo may be outdated or download may be incomplete. '\
'If the "repo_url" is overridden, consider switching to '\
'the default repo.'.format(fname))
raise UserWarning(
'File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
print("download failed, retrying, {} attempt{} left"
.format(retries, 's' if retries > 1 else ''))
print('download failed due to {}, retrying, {} attempt{} left'
.format(repr(e), retries, 's' if retries > 1 else ''))

return fname

Expand Down
46 changes: 40 additions & 6 deletions tests/python/unittest/test_gluon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import os
import tempfile
import warnings
import glob
import shutil
import multiprocessing as mp

try:
from unittest import mock
Expand Down Expand Up @@ -46,22 +49,53 @@ def test_download_retries():

@mock.patch(
'requests.get',
mock.Mock(side_effect=
lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
def _download_successful(tmp):
""" internal use for testing download successfully """
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmp)


def test_download_successful():
""" test download with one process """
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100
_download_successful(tmpfile)
assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
pattern = os.path.join(tmp, 'README.md*')
# check only one file we want left
assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
# delete temp dir
shutil.rmtree(tmp)


def test_multiprocessing_download_successful():
    """ test download with multiprocessing

    Several processes download to the same target path at once; afterwards the
    target must exist with the expected size and no temporary files may remain.
    """
    num_processes = 10
    tmp = tempfile.mkdtemp()
    try:
        tmpfile = os.path.join(tmp, 'README.md')
        process_list = []
        for _ in range(num_processes):
            process = mp.Process(target=_download_successful, args=(tmpfile,))
            process_list.append(process)
            process.start()
        for process in process_list:
            process.join()
        assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
        # check only one file we want left
        pattern = os.path.join(tmp, 'README.md*')
        assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
    finally:
        # delete temp dir even when an assertion above fails
        shutil.rmtree(tmp)


@mock.patch(
'requests.get',
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
""" test download verify_ssl parameter """
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)
Expand Down

0 comments on commit 751f40e

Please sign in to comment.