From efebedf85b7d23b4962495ac62b826be5c71aa50 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 18 Jan 2020 07:59:33 +0000 Subject: [PATCH 1/5] use temp file --- python/mxnet/gluon/model_zoo/model_store.py | 18 +++-- python/mxnet/gluon/utils.py | 70 ++++++++++++------- tests/python/unittest/test_gluon_model_zoo.py | 17 +++++ 3 files changed, 73 insertions(+), 32 deletions(-) diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index 11ac47bae905..f19f23cf7763 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -22,8 +22,9 @@ import os import zipfile import logging +import tempfile -from ..utils import download, check_sha1 +from ..utils import download, check_sha1, replace_file from ... import base, util _model_sha1 = {name: checksum for checksum, name in [ @@ -107,12 +108,15 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) if repo_url[-1] != '/': repo_url = repo_url + '/' - download(_url_format.format(repo_url=repo_url, file_name=file_name), - path=zip_file_path, - overwrite=True) - with zipfile.ZipFile(zip_file_path) as zf: - zf.extractall(root) - os.remove(zip_file_path) + + with tempfile.NamedTemporaryFile(dir=root) as zip_file: + download(_url_format.format(repo_url=repo_url, file_name=file_name), + path=zip_file.name, overwrite=True, inplace=True) + with zipfile.ZipFile(zip_file) as zf: + with tempfile.TemporaryDirectory(dir=root) as temp_dir: + zf.extractall(temp_dir) + temp_file_path = os.path.join(temp_dir, file_name+'.params') + replace_file(temp_file_path, file_path) if check_sha1(file_path, sha1_hash): return file_path diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 83ed15aed450..f86f947c627e 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -21,7 +21,7 @@ from __future__ import absolute_import __all__ = ['split_data', 'split_and_load', 'clip_global_norm', - 'check_sha1', 'download'] + 'check_sha1', 'download', 'replace_file'] import os import sys @@ -35,7 +35,7 @@ import numpy as np from .. import ndarray -from ..util import is_np_shape, is_np_array +from ..util import is_np_shape, is_np_array, makedirs from .. import numpy as _mx_np # pylint: disable=reimported @@ -197,8 +197,14 @@ def check_sha1(filename, sha1_hash): if not sys.platform.startswith('win32'): # refer to https://github.com/untitaker/python-atomicwrites - def _replace_atomic(src, dst): - """Implement atomic os.replace with linux and OSX. Internal use only""" + def replace_file(src, dst): + """Implement atomic os.replace with linux and OSX. + + Parameters + ---------- + src : source file path + dst : destination file path + """ try: os.rename(src, dst) except OSError: @@ -240,18 +246,25 @@ def _handle_errors(rv, src): finally: raise OSError(msg) - def _replace_atomic(src, dst): + def replace_file(src, dst): """Implement atomic os.replace with windows. + refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw The function fails when one of the process(copy, flush, delete) fails. - Internal use only""" + + Parameters + ---------- + src : source file path + dst : destination file path + """ _handle_errors(ctypes.windll.kernel32.MoveFileExW( _str_to_unicode(src), _str_to_unicode(dst), _windows_default_flags | _MOVEFILE_REPLACE_EXISTING ), src) -def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): +def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, + inplace=False): """Download an given URL Parameters @@ -270,6 +283,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ The number of times to attempt the download in case of failure or non 200 return codes verify_ssl : bool, default True Verify SSL certificates. + inplace : bool, default False + Whether to write to the file at destination path inplace. Usually used if the path of temp + file is provided. Returns ------- @@ -298,7 +314,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) if not os.path.exists(dirname): - os.makedirs(dirname) + makedirs(dirname) while retries + 1 > 0: # Disable pyling too broad Exception # pylint: disable=W0703 @@ -307,26 +323,30 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ r = requests.get(url, stream=True, verify=verify_ssl) if r.status_code != 200: raise RuntimeError('Failed downloading url {}'.format(url)) - # create uuid for temporary files - random_uuid = str(uuid.uuid4()) - with open('{}.{}'.format(fname, random_uuid), 'wb') as f: + if inplace: + temp_file_name = fname + else: + random_uuid = str(uuid.uuid4()) + temp_file_name = '{}.{}'.format(fname, random_uuid) + with open(temp_file_name, 'wb') as f: for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks f.write(chunk) - # if the target file exists(created by other processes) - # and have the same hash with target file - # delete the temporary file - if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): - # atmoic operation in the same file system - _replace_atomic('{}.{}'.format(fname, random_uuid), fname) - else: - try: - os.remove('{}.{}'.format(fname, random_uuid)) - except OSError: - pass - finally: - warnings.warn( - 'File {} exists in file system so the downloaded file is deleted'.format(fname)) + if not inplace: + # if the target file exists(created by other processes) + # and have the same hash with target file + # delete the temporary file + if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): + # atmoic operation in the same file system + replace_file('{}.{}'.format(fname, random_uuid), fname) + else: + try: + os.remove('{}.{}'.format(fname, random_uuid)) + except OSError: + pass + finally: + warnings.warn( + 'File {} exists in file system so the downloaded file is deleted'.format(fname)) if sha1_hash and not check_sha1(fname, sha1_hash): raise UserWarning( 'File {} is downloaded but the content hash does not match.' diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index a64668451a25..5a3d8f3f0b46 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -20,6 +20,7 @@ from mxnet.gluon.model_zoo.vision import get_model import sys from common import setup_module, with_seed, teardown +import multiprocessing def eprint(*args, **kwargs): @@ -50,6 +51,22 @@ def test_models(): model(mx.nd.random.uniform(shape=data_shape)).wait_to_read() +@with_seed() +def test_parallel_download(): + def fn(model_name): + model = get_model(model_name, pretrained=True, root='./parallel_download') + print(type(model)) + + processes = [] + name = 'mobilenetv2_0.25' + for _ in range(10): + p = multiprocessing.Process(target=fn, args=(name,)) + processes.append(p) + for p in processes: + p.start() + for p in processes: + p.join() + if __name__ == '__main__': import nose nose.runmodule() From c8578567d79a15bb8e14928116a4957f05427041 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 22 Jan 2020 23:53:40 +0000 Subject: [PATCH 2/5] fix dependency --- python/mxnet/gluon/model_zoo/model_store.py | 10 ++++++---- python/mxnet/gluon/utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index f19f23cf7763..5132f045a4b4 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -23,6 +23,7 @@ import zipfile import logging import tempfile +import shutil from ..utils import download, check_sha1, replace_file from ... import base, util @@ -113,10 +114,11 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): download(_url_format.format(repo_url=repo_url, file_name=file_name), path=zip_file.name, overwrite=True, inplace=True) with zipfile.ZipFile(zip_file) as zf: - with tempfile.TemporaryDirectory(dir=root) as temp_dir: - zf.extractall(temp_dir) - temp_file_path = os.path.join(temp_dir, file_name+'.params') - replace_file(temp_file_path, file_path) + temp_dir = tempfile.mkdtemp(dir=root) + zf.extractall(temp_dir) + temp_file_path = os.path.join(temp_dir, file_name+'.params') + replace_file(temp_file_path, file_path) + shutil.rmtree(temp_dir) if check_sha1(file_path, sha1_hash): return file_path diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index f86f947c627e..695c8a0f64e7 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -265,7 +265,7 @@ def replace_file(src, dst): def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, inplace=False): - """Download an given URL + """Download a given URL Parameters ---------- From 224df46509b1c803e3f65bda244813ca5ce28693 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Wed, 22 Jan 2020 20:58:05 -0800 Subject: [PATCH 3/5] Update model_store.py --- python/mxnet/gluon/model_zoo/model_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index 5132f045a4b4..f437b632d430 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -105,7 +105,6 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): util.makedirs(root) - zip_file_path = os.path.join(root, file_name+'.zip') repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) if repo_url[-1] != '/': repo_url = repo_url + '/' From 7eaf7b6d8571b2b5afff368bf950cd1a451b34a6 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Wed, 22 Jan 2020 20:58:51 -0800 Subject: [PATCH 4/5] Update test_gluon_model_zoo.py --- tests/python/unittest/test_gluon_model_zoo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index 5a3d8f3f0b46..d53dd403a5b8 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -50,17 +50,16 @@ def test_models(): model.collect_params().initialize() model(mx.nd.random.uniform(shape=data_shape)).wait_to_read() +def parallel_download(model_name): + model = get_model(model_name, pretrained=True, root='./parallel_download') + print(type(model)) @with_seed() def test_parallel_download(): - def fn(model_name): - model = get_model(model_name, pretrained=True, root='./parallel_download') - print(type(model)) - processes = [] name = 'mobilenetv2_0.25' for _ in range(10): - p = multiprocessing.Process(target=fn, args=(name,)) + p = multiprocessing.Process(target=parallel_download, args=(name,)) processes.append(p) for p in processes: p.start() From 7281767eaf5244bfd1036bce90c38fad12d112c2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 23 Jan 2020 20:05:01 +0000 Subject: [PATCH 5/5] remove NamedTempFile --- python/mxnet/gluon/model_zoo/model_store.py | 21 +++++----- python/mxnet/gluon/utils.py | 44 +++++++++------------ 2 files changed, 30 insertions(+), 35 deletions(-) diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index 5132f045a4b4..389d8663e5af 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -23,6 +23,7 @@ import zipfile import logging import tempfile +import uuid import shutil from ..utils import download, check_sha1, replace_file @@ -110,15 +111,17 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): if repo_url[-1] != '/': repo_url = repo_url + '/' - with tempfile.NamedTemporaryFile(dir=root) as zip_file: - download(_url_format.format(repo_url=repo_url, file_name=file_name), - path=zip_file.name, overwrite=True, inplace=True) - with zipfile.ZipFile(zip_file) as zf: - temp_dir = tempfile.mkdtemp(dir=root) - zf.extractall(temp_dir) - temp_file_path = os.path.join(temp_dir, file_name+'.params') - replace_file(temp_file_path, file_path) - shutil.rmtree(temp_dir) + random_uuid = str(uuid.uuid4()) + temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid) + download(_url_format.format(repo_url=repo_url, file_name=file_name), + path=temp_zip_file_path, overwrite=True) + with zipfile.ZipFile(temp_zip_file_path) as zf: + temp_dir = tempfile.mkdtemp(dir=root) + zf.extractall(temp_dir) + temp_file_path = os.path.join(temp_dir, file_name+'.params') + replace_file(temp_file_path, file_path) + shutil.rmtree(temp_dir) + os.remove(temp_zip_file_path) if check_sha1(file_path, sha1_hash): return file_path diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 695c8a0f64e7..6521472bc41e 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -263,8 +263,7 @@ def replace_file(src, dst): ), src) -def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True, - inplace=False): +def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): """Download a given URL Parameters @@ -283,9 +282,6 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ The number of times to attempt the download in case of failure or non 200 return codes verify_ssl : bool, default True Verify SSL certificates. - inplace : bool, default False - Whether to write to the file at destination path inplace. Usually used if the path of temp - file is provided. Returns ------- @@ -323,30 +319,26 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ r = requests.get(url, stream=True, verify=verify_ssl) if r.status_code != 200: raise RuntimeError('Failed downloading url {}'.format(url)) - if inplace: - temp_file_name = fname - else: - random_uuid = str(uuid.uuid4()) - temp_file_name = '{}.{}'.format(fname, random_uuid) - with open(temp_file_name, 'wb') as f: + # create uuid for temporary files + random_uuid = str(uuid.uuid4()) + with open('{}.{}'.format(fname, random_uuid), 'wb') as f: for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks f.write(chunk) - if not inplace: - # if the target file exists(created by other processes) - # and have the same hash with target file - # delete the temporary file - if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): - # atmoic operation in the same file system - replace_file('{}.{}'.format(fname, random_uuid), fname) - else: - try: - os.remove('{}.{}'.format(fname, random_uuid)) - except OSError: - pass - finally: - warnings.warn( - 'File {} exists in file system so the downloaded file is deleted'.format(fname)) + # if the target file exists(created by other processes) + # and have the same hash with target file + # delete the temporary file + if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): + # atmoic operation in the same file system + replace_file('{}.{}'.format(fname, random_uuid), fname) + else: + try: + os.remove('{}.{}'.format(fname, random_uuid)) + except OSError: + pass + finally: + warnings.warn( + 'File {} exists in file system so the downloaded file is deleted'.format(fname)) if sha1_hash and not check_sha1(fname, sha1_hash): raise UserWarning( 'File {} is downloaded but the content hash does not match.'