diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py
index 11ac47bae905..6da7dd106556 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -22,8 +22,11 @@
 import os
 import zipfile
 import logging
+import tempfile
+import uuid
+import shutil
 
-from ..utils import download, check_sha1
+from ..utils import download, check_sha1, replace_file
 from ... import base, util
 
 _model_sha1 = {name: checksum for checksum, name in [
@@ -103,16 +106,29 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
 
     util.makedirs(root)
 
-    zip_file_path = os.path.join(root, file_name+'.zip')
     repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
     if repo_url[-1] != '/':
         repo_url = repo_url + '/'
+
+    # Use names unique to this call for the downloaded archive and for the
+    # extraction directory.
+    random_uuid = str(uuid.uuid4())
+    temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid)
     download(_url_format.format(repo_url=repo_url, file_name=file_name),
-             path=zip_file_path,
-             overwrite=True)
-    with zipfile.ZipFile(zip_file_path) as zf:
-        zf.extractall(root)
-    os.remove(zip_file_path)
+             path=temp_zip_file_path, overwrite=True)
+    try:
+        temp_dir = tempfile.mkdtemp(dir=root)
+        try:
+            with zipfile.ZipFile(temp_zip_file_path) as zf:
+                zf.extractall(temp_dir)
+            temp_file_path = os.path.join(temp_dir, file_name+'.params')
+            replace_file(temp_file_path, file_path)
+        finally:
+            # always drop the extraction directory, even when unpacking fails
+            shutil.rmtree(temp_dir)
+    finally:
+        # always drop the downloaded archive, even when unpacking fails
+        os.remove(temp_zip_file_path)
 
     if check_sha1(file_path, sha1_hash):
         return file_path
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 83ed15aed450..6521472bc41e 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -21,7 +21,7 @@
 from __future__ import absolute_import
 
 __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
-           'check_sha1', 'download']
+           'check_sha1', 'download', 'replace_file']
 
 import os
 import sys
@@ -35,7 +35,7 @@
 import numpy as np
 
 from .. import ndarray
-from ..util import is_np_shape, is_np_array
+from ..util import is_np_shape, is_np_array, makedirs
 from .. import numpy as _mx_np  # pylint: disable=reimported
 
 
@@ -197,8 +197,16 @@ def check_sha1(filename, sha1_hash):
 
 if not sys.platform.startswith('win32'):
     # refer to https://github.com/untitaker/python-atomicwrites
-    def _replace_atomic(src, dst):
-        """Implement atomic os.replace with linux and OSX. Internal use only"""
+    def replace_file(src, dst):
+        """Implement atomic os.replace with linux and OSX.
+
+        Parameters
+        ----------
+        src : str
+            Source file path.
+        dst : str
+            Destination file path.
+        """
         try:
             os.rename(src, dst)
         except OSError:
@@ -240,11 +248,19 @@ def _handle_errors(rv, src):
             finally:
                 raise OSError(msg)
 
-    def _replace_atomic(src, dst):
+    def replace_file(src, dst):
         """Implement atomic os.replace with windows.
+
         refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
         The function fails when one of the process(copy, flush, delete) fails.
-        Internal use only"""
+
+        Parameters
+        ----------
+        src : str
+            Source file path.
+        dst : str
+            Destination file path.
+        """
         _handle_errors(ctypes.windll.kernel32.MoveFileExW(
             _str_to_unicode(src), _str_to_unicode(dst),
             _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
@@ -252,7 +268,7 @@ def _replace_atomic(src, dst):
 
 
 def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
-    """Download an given URL
+    """Download a given URL
 
     Parameters
     ----------
@@ -298,7 +314,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
     if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
-            os.makedirs(dirname)
+            makedirs(dirname)
         while retries + 1 > 0:
-            # Disable pyling too broad Exception
+            # Disable pylint too broad Exception
             # pylint: disable=W0703
@@ -318,7 +334,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
                 # delete the temporary file
                 if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
-                    # atmoic operation in the same file system
-                    _replace_atomic('{}.{}'.format(fname, random_uuid), fname)
+                    # atomic operation in the same file system
+                    replace_file('{}.{}'.format(fname, random_uuid), fname)
                 else:
                     try:
                         os.remove('{}.{}'.format(fname, random_uuid))
diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index a64668451a25..d53dd403a5b8 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -20,6 +20,7 @@
 from mxnet.gluon.model_zoo.vision import get_model
 import sys
 from common import setup_module, with_seed, teardown
+import multiprocessing
 
 
 def eprint(*args, **kwargs):
@@ -49,6 +50,26 @@ def test_models():
             model.collect_params().initialize()
         model(mx.nd.random.uniform(shape=data_shape)).wait_to_read()
 
+def parallel_download(model_name):
+    """Fetch a pretrained model using './parallel_download' as the model root."""
+    model = get_model(model_name, pretrained=True, root='./parallel_download')
+    print(type(model))
+
+@with_seed()
+def test_parallel_download():
+    """Fetch the same pretrained model from 10 processes at the same time."""
+    processes = []
+    name = 'mobilenetv2_0.25'
+    for _ in range(10):
+        p = multiprocessing.Process(target=parallel_download, args=(name,))
+        processes.append(p)
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    # A failure inside a child process is only visible through its exit code.
+    for p in processes:
+        assert p.exitcode == 0, 'parallel download failed in a child process'
 
 if __name__ == '__main__':
     import nose