Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Remove fetching test files by ETag.
Browse files Browse the repository at this point in the history
Will add it as a separate PR as per review comments.

Signed-off-by: Acharya <[email protected]>
  • Loading branch information
Acharya committed Mar 14, 2018
1 parent a9b2f62 commit e5443ac
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
4 changes: 2 additions & 2 deletions example/onnx/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
def import_onnx():
"""Import the onnx model into mxnet"""
model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
download(model_url, 'super_resolution.onnx', version_tag='"7348c879d16c42bc77e24e270f663524"')
download(model_url, 'super_resolution.onnx')

LOGGER.info("Converting onnx format to mxnet's symbol and params...")
sym, params = onnx_mxnet.import_model('super_resolution.onnx')
Expand All @@ -46,7 +46,7 @@ def get_test_image():
# Load test image
input_image_dim = 224
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
download(img_url, 'super_res_input.jpg', version_tag='"02c90a7248e51316b11f7f39dd1b226d"')
download(img_url, 'super_res_input.jpg')
img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
img_ycbcr = img.convert("YCbCr")
img_y, img_cb, img_cr = img_ycbcr.split()
Expand Down
16 changes: 10 additions & 6 deletions python/mxnet/contrib/onnx/_import/import_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
from .import_onnx import GraphProto

def import_model(model_file):
"""Imports the supplied ONNX model file into MXNet symbol and parameters.
:parameters model_file
"""Imports the ONNX model file passed as a parameter into MXNet symbol and parameters.
Parameters
----------
model_file : ONNX model file name
model_file : str
ONNX model file name
:returns (sym, params)
Returns
-------
sym : mx.symbol
Compatible mxnet symbol
Mxnet symbol and parameter objects.
sym : mxnet.symbol
Mxnet symbol
params : dict of str to mx.ndarray
Dict of converted parameters stored in mx.ndarray format
"""
Expand Down
10 changes: 2 additions & 8 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,8 +1367,7 @@ def list_gpus():
pass
return range(len([i for i in re.split('\n') if 'GPU' in i]))


def download(url, fname=None, dirname=None, overwrite=False, version_tag=None):
def download(url, fname=None, dirname=None, overwrite=False):
"""Download an given URL
Parameters
Expand All @@ -1386,8 +1385,6 @@ def download(url, fname=None, dirname=None, overwrite=False, version_tag=None):
Default is false, which means skipping download if the local file
exists. If true, then download the url to overwrite the local file if
exists.
version_tag : str, optional
the version tag of the file.
Returns
-------
Expand All @@ -1410,15 +1407,12 @@ def download(url, fname=None, dirname=None, overwrite=False, version_tag=None):
if exc.errno != errno.EEXIST:
raise OSError('failed to create ' + dirname)

if not overwrite and os.path.exists(fname) and not version_tag:
if not overwrite and os.path.exists(fname):
logging.info("%s exists, skipping download", fname)
return fname

r = requests.get(url, stream=True)
assert r.status_code == 200, "failed to open %s" % url
if version_tag and r.headers['ETag'] != version_tag:
logging.info("The version tag of the file does not match the expected version. "
+ "Proceeding with the file download...")
with open(fname, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
Expand Down

0 comments on commit e5443ac

Please sign in to comment.