Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] fix model zoo parallel download #17372

Merged
merged 6 commits into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions python/mxnet/gluon/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import os
import zipfile
import logging
import tempfile
import uuid
import shutil

from ..utils import download, check_sha1
from ..utils import download, check_sha1, replace_file
from ... import base, util

_model_sha1 = {name: checksum for checksum, name in [
Expand Down Expand Up @@ -103,16 +106,21 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):

util.makedirs(root)

zip_file_path = os.path.join(root, file_name+'.zip')
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
if repo_url[-1] != '/':
repo_url = repo_url + '/'

random_uuid = str(uuid.uuid4())
temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid)
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(root)
os.remove(zip_file_path)
path=temp_zip_file_path, overwrite=True)
with zipfile.ZipFile(temp_zip_file_path) as zf:
temp_dir = tempfile.mkdtemp(dir=root)
zf.extractall(temp_dir)
temp_file_path = os.path.join(temp_dir, file_name+'.params')
replace_file(temp_file_path, file_path)
shutil.rmtree(temp_dir)
os.remove(temp_zip_file_path)

if check_sha1(file_path, sha1_hash):
return file_path
Expand Down
30 changes: 21 additions & 9 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from __future__ import absolute_import

__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
'check_sha1', 'download']
'check_sha1', 'download', 'replace_file']

import os
import sys
Expand All @@ -35,7 +35,7 @@
import numpy as np

from .. import ndarray
from ..util import is_np_shape, is_np_array
from ..util import is_np_shape, is_np_array, makedirs
from .. import numpy as _mx_np # pylint: disable=reimported


Expand Down Expand Up @@ -197,8 +197,14 @@ def check_sha1(filename, sha1_hash):

if not sys.platform.startswith('win32'):
# refer to https://github.com/untitaker/python-atomicwrites
def _replace_atomic(src, dst):
"""Implement atomic os.replace with linux and OSX. Internal use only"""
def replace_file(src, dst):
"""Implement atomic os.replace with linux and OSX.

Parameters
----------
src : source file path
dst : destination file path
"""
try:
os.rename(src, dst)
except OSError:
Expand Down Expand Up @@ -240,19 +246,25 @@ def _handle_errors(rv, src):
finally:
raise OSError(msg)

def _replace_atomic(src, dst):
def replace_file(src, dst):
"""Implement atomic os.replace with windows.

refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
The function fails when one of the process(copy, flush, delete) fails.
Internal use only"""

Parameters
----------
src : source file path
dst : destination file path
"""
_handle_errors(ctypes.windll.kernel32.MoveFileExW(
_str_to_unicode(src), _str_to_unicode(dst),
_windows_default_flags | _MOVEFILE_REPLACE_EXISTING
), src)


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL
"""Download a given URL

Parameters
----------
Expand Down Expand Up @@ -298,7 +310,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
makedirs(dirname)
while retries + 1 > 0:
# Disable pyling too broad Exception
# pylint: disable=W0703
Expand All @@ -318,7 +330,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
# delete the temporary file
if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
# atmoic operation in the same file system
_replace_atomic('{}.{}'.format(fname, random_uuid), fname)
replace_file('{}.{}'.format(fname, random_uuid), fname)
else:
try:
os.remove('{}.{}'.format(fname, random_uuid))
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_gluon_model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mxnet.gluon.model_zoo.vision import get_model
import sys
from common import setup_module, with_seed, teardown
import multiprocessing


def eprint(*args, **kwargs):
Expand Down Expand Up @@ -49,6 +50,21 @@ def test_models():
model.collect_params().initialize()
model(mx.nd.random.uniform(shape=data_shape)).wait_to_read()

def parallel_download(model_name):
model = get_model(model_name, pretrained=True, root='./parallel_download')
print(type(model))

@with_seed()
def test_parallel_download():
processes = []
name = 'mobilenetv2_0.25'
for _ in range(10):
p = multiprocessing.Process(target=parallel_download, args=(name,))
processes.append(p)
for p in processes:
p.start()
for p in processes:
p.join()

if __name__ == '__main__':
import nose
Expand Down