diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index d5a14a6859a7..78324986760a 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -22,7 +22,9 @@ 'check_sha1', 'download'] import os +import sys import hashlib +import uuid import warnings import collections import weakref @@ -195,6 +197,62 @@ def check_sha1(filename, sha1_hash): return sha1.hexdigest() == sha1_hash +if not sys.platform.startswith('win32'): + # refer to https://github.com/untitaker/python-atomicwrites + def _replace_atomic(src, dst): + """Implement atomic os.replace with linux and OSX. Internal use only""" + try: + os.rename(src, dst) + except OSError: + try: + os.remove(src) + except OSError: + pass + finally: + raise OSError( + 'Moving downloaded temp file - {}, to {} failed. \ + Please retry the download.'.format(src, dst)) +else: + import ctypes + + _MOVEFILE_REPLACE_EXISTING = 0x1 + # Setting this value guarantees that a move performed as a copy + # and delete operation is flushed to disk before the function returns. + # The flush occurs at the end of the copy operation. + _MOVEFILE_WRITE_THROUGH = 0x8 + _windows_default_flags = _MOVEFILE_WRITE_THROUGH + + text_type = unicode if sys.version_info[0] == 2 else str # noqa + + def _str_to_unicode(x): + """Handle text decoding. Internal use only""" + if not isinstance(x, text_type): + return x.decode(sys.getfilesystemencoding()) + return x + + def _handle_errors(rv, src): + """Handle WinError. Internal use only""" + if not rv: + msg = ctypes.FormatError(ctypes.GetLastError()) + # if the MoveFileExW fails(e.g. fail to acquire file lock), removes the tempfile + try: + os.remove(src) + except OSError: + pass + finally: + raise OSError(msg) + + def _replace_atomic(src, dst): + """Implement atomic os.replace with windows. + refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw + The function fails when one of the process(copy, flush, delete) fails. + Internal use only""" + _handle_errors(ctypes.windll.kernel32.MoveFileExW( + _str_to_unicode(src), _str_to_unicode(dst), + _windows_default_flags | _MOVEFILE_REPLACE_EXISTING + ), src) + + def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): """Download an given URL @@ -231,7 +289,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ fname = os.path.join(path, url.split('/')[-1]) else: fname = path - assert retries >= 0, "Number of retries should be at least 0" + assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format( + retries) if not verify_ssl: warnings.warn( @@ -242,31 +301,48 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): os.makedirs(dirname) - while retries+1 > 0: + while retries + 1 > 0: # Disable pyling too broad Exception # pylint: disable=W0703 try: - print('Downloading %s from %s...'%(fname, url)) + print('Downloading {} from {}...'.format(fname, url)) r = requests.get(url, stream=True, verify=verify_ssl) if r.status_code != 200: - raise RuntimeError("Failed downloading url %s"%url) - with open(fname, 'wb') as f: + raise RuntimeError('Failed downloading url {}'.format(url)) + # create uuid for temporary files + random_uuid = str(uuid.uuid4()) + with open('{}.{}'.format(fname, random_uuid), 'wb') as f: for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks f.write(chunk) + # if the target file exists(created by other processes) + # and have the same hash with target file + # delete the temporary file + if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): + # atmoic operation in the same file system + _replace_atomic('{}.{}'.format(fname, random_uuid), fname) + else: + try: + os.remove('{}.{}'.format(fname, random_uuid)) + except OSError: + pass + finally: + warnings.warn( + 'File {} exists in file system so the downloaded file is deleted'.format(fname)) if sha1_hash and not check_sha1(fname, sha1_hash): - raise UserWarning('File {} is downloaded but the content hash does not match.'\ - ' The repo may be outdated or download may be incomplete. '\ - 'If the "repo_url" is overridden, consider switching to '\ - 'the default repo.'.format(fname)) + raise UserWarning( + 'File {} is downloaded but the content hash does not match.' + ' The repo may be outdated or download may be incomplete. ' + 'If the "repo_url" is overridden, consider switching to ' + 'the default repo.'.format(fname)) break except Exception as e: retries -= 1 if retries <= 0: raise e else: - print("download failed, retrying, {} attempt{} left" - .format(retries, 's' if retries > 1 else '')) + print('download failed due to {}, retrying, {} attempt{} left' + .format(repr(e), retries, 's' if retries > 1 else '')) return fname diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py index 431852427f53..20f1c8c549ad 100644 --- a/tests/python/unittest/test_gluon_utils.py +++ b/tests/python/unittest/test_gluon_utils.py @@ -19,6 +19,9 @@ import os import tempfile import warnings +import glob +import shutil +import multiprocessing as mp try: from unittest import mock @@ -46,15 +49,45 @@ def test_download_retries(): @mock.patch( 'requests.get', - mock.Mock(side_effect= - lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100))) + mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100))) +def _download_successful(tmp): + """ internal use for testing download successfully """ + mx.gluon.utils.download( + "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md", + path=tmp) + + def test_download_successful(): + """ test download with one process """ tmp = tempfile.mkdtemp() tmpfile = os.path.join(tmp, 'README.md') - mx.gluon.utils.download( - "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md", - path=tmpfile) - assert os.path.getsize(tmpfile) > 100 + _download_successful(tmpfile) + assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile) + pattern = os.path.join(tmp, 'README.md*') + # check only one file we want left + assert len(glob.glob(pattern)) == 1, glob.glob(pattern) + # delete temp dir + shutil.rmtree(tmp) + + +def test_multiprocessing_download_successful(): + """ test download with multiprocessing """ + tmp = tempfile.mkdtemp() + tmpfile = os.path.join(tmp, 'README.md') + process_list = [] + # test it with 10 processes + for i in range(10): + process_list.append(mp.Process( + target=_download_successful, args=(tmpfile,))) + process_list[i].start() + for i in range(10): + process_list[i].join() + assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile) + # check only one file we want left + pattern = os.path.join(tmp, 'README.md*') + assert len(glob.glob(pattern)) == 1, glob.glob(pattern) + # delete temp dir + shutil.rmtree(tmp) @mock.patch( @@ -62,6 +95,7 @@ def test_download_successful(): mock.Mock( side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT'))) def test_download_ssl_verify(): + """ test download verify_ssl parameter """ with warnings.catch_warnings(record=True) as warnings_: mx.gluon.utils.download( "https://mxnet.incubator.apache.org/index.html", verify_ssl=False)