From addd518ded80855a956a79ea4717e29194849e4a Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Tue, 10 Jul 2018 22:15:55 +0200 Subject: [PATCH] set MXNET_DATA_DIR as base for downloaded models through base.data_dir() --- docs/faq/env_var.md | 5 ++ python/mxnet/base.py | 27 +++++++++++ python/mxnet/contrib/text/embedding.py | 9 ++-- python/mxnet/gluon/contrib/data/text.py | 11 ++--- python/mxnet/gluon/data/vision/datasets.py | 18 +++---- python/mxnet/gluon/model_zoo/model_store.py | 17 +++---- .../mxnet/gluon/model_zoo/vision/__init__.py | 2 +- .../mxnet/gluon/model_zoo/vision/alexnet.py | 5 +- .../mxnet/gluon/model_zoo/vision/densenet.py | 13 ++--- .../mxnet/gluon/model_zoo/vision/inception.py | 5 +- .../mxnet/gluon/model_zoo/vision/mobilenet.py | 9 ++-- python/mxnet/gluon/model_zoo/vision/resnet.py | 25 +++++----- .../gluon/model_zoo/vision/squeezenet.py | 9 ++-- python/mxnet/gluon/model_zoo/vision/vgg.py | 21 ++++---- python/mxnet/test_utils.py | 2 +- tests/python/unittest/test_base.py | 48 +++++++++++++++++++ 16 files changed, 157 insertions(+), 69 deletions(-) create mode 100644 tests/python/unittest/test_base.py diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 12a898aadc24..8d01b84e17e3 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -126,6 +126,11 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca - Values: String ```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'``` - The repository url to be used for Gluon datasets and pre-trained models. +* MXNET_DATA_DIR + - Data directory in the filesystem for storage, for example when downloading gluon models. + - Default in *nix is .local/share/mxnet APPDATA/mxnet in windows, XDG_DATA_HOME/mxnet if + XDG_DATA_HOME is set. + Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 0fb73b3c7dda..ea6d145bea3f 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -27,6 +27,7 @@ import warnings import inspect import numpy as np +import platform from . import libinfo warnings.filterwarnings('default', category=DeprecationWarning) @@ -47,6 +48,32 @@ integer_types = (int, long, np.int32, np.int64) py_str = lambda x: x + +def data_dir_default(): + """ + + :return: default data directory depending on the platform and environment variables + """ + system = platform.system() + if system == 'Windows': + return os.environ.get('APPDATA') + elif system == 'Linux' or system == 'Darwin': + if 'XDG_DATA_HOME'in os.environ: + return os.path.join(os.environ['XDG_DATA_HOME'], 'mxnet') + else: + return os.path.join(os.path.expanduser("~"), '.local', 'share', 'mxnet') + else: + return os.path.join(os.path.expanduser("~"), '.mxnet') + + +def data_dir(): + """ + + :return: data directory in the filesystem for storage, for example when downloading models + """ + return os.getenv('MXNET_DATA_DIR', data_dir_default()) + + class _NullType(object): """Placeholder for arguments""" def __repr__(self): diff --git a/python/mxnet/contrib/text/embedding.py b/python/mxnet/contrib/text/embedding.py index 6598718e6b01..9b12746db57e 100644 --- a/python/mxnet/contrib/text/embedding.py +++ b/python/mxnet/contrib/text/embedding.py @@ -34,6 +34,7 @@ from . import vocab from ... import ndarray as nd from ... import registry +from ... import base def register(embedding_cls): @@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding): ---------- pretrained_file_name : str, default 'glove.840B.300d.txt' The name of the pre-trained token embedding file. - embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings') + embedding_root : str, default $MXNET_DATA_DIR/embeddings The root directory for storing embedding-related files. init_unknown_vec : callback The callback used to initialize the embedding vector for the unknown token. @@ -541,7 +542,7 @@ def _get_download_file_name(cls, pretrained_file_name): return archive def __init__(self, pretrained_file_name='glove.840B.300d.txt', - embedding_root=os.path.join('~', '.mxnet', 'embeddings'), + embedding_root=os.path.join(base.data_dir(),'embeddings'), init_unknown_vec=nd.zeros, vocabulary=None, **kwargs): GloVe._check_pretrained_file_names(pretrained_file_name) @@ -600,7 +601,7 @@ class FastText(_TokenEmbedding): ---------- pretrained_file_name : str, default 'wiki.en.vec' The name of the pre-trained token embedding file. - embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings') + embedding_root : str, default $MXNET_DATA_DIR/embeddings The root directory for storing embedding-related files. init_unknown_vec : callback The callback used to initialize the embedding vector for the unknown token. @@ -642,7 +643,7 @@ def _get_download_file_name(cls, pretrained_file_name): return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip' def __init__(self, pretrained_file_name='wiki.simple.vec', - embedding_root=os.path.join('~', '.mxnet', 'embeddings'), + embedding_root=os.path.join(base.data_dir(), 'embeddings'), init_unknown_vec=nd.zeros, vocabulary=None, **kwargs): FastText._check_pretrained_file_names(pretrained_file_name) diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py index 98fe6b657f2b..15ac6c764b81 100644 --- a/python/mxnet/gluon/contrib/data/text.py +++ b/python/mxnet/gluon/contrib/data/text.py @@ -30,8 +30,7 @@ from ...data import dataset from ...utils import download, check_sha1, _get_repo_file_url from ....contrib import text -from .... import nd - +from .... import nd, base class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method def __init__(self, root, namespace, vocabulary): @@ -116,7 +115,7 @@ class WikiText2(_WikiText): Parameters ---------- - root : str, default '~/.mxnet/datasets/wikitext-2' + root : str, default $MXNET_DATA_DIR/datasets/wikitext-2 Path to temp folder for storing data. segment : str, default 'train' Dataset segment. Options are 'train', 'validation', 'test'. @@ -127,7 +126,7 @@ class WikiText2(_WikiText): The sequence length of each sample, regardless of the sentence boundary. """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-2'), segment='train', vocab=None, seq_len=35): self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe') self._data_file = {'train': ('wiki.train.tokens', @@ -154,7 +153,7 @@ class WikiText103(_WikiText): Parameters ---------- - root : str, default '~/.mxnet/datasets/wikitext-103' + root : str, default $MXNET_DATA_DIR/datasets/wikitext-103 Path to temp folder for storing data. segment : str, default 'train' Dataset segment. Options are 'train', 'validation', 'test'. @@ -164,7 +163,7 @@ class WikiText103(_WikiText): seq_len : int, default 35 The sequence length of each sample, regardless of the sentence boundary. """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-103'), segment='train', vocab=None, seq_len=35): self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a') self._data_file = {'train': ('wiki.train.tokens', diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py index 74a5aebf17bb..02c971f9128b 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -30,7 +30,7 @@ from .. import dataset from ...utils import download, check_sha1, _get_repo_file_url -from .... import nd, image, recordio +from .... import nd, image, recordio, base class MNIST(dataset._DownloadedDataset): @@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset): Parameters ---------- - root : str, default '~/.mxnet/datasets/mnist' + root : str, default $MXNET_DATA_DIR/datasets/mnist Path to temp folder for storing data. train : bool, default True Whether to load the training or testing set. @@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'), train=True, transform=None): self._train = train self._train_data = ('train-images-idx3-ubyte.gz', @@ -101,7 +101,7 @@ class FashionMNIST(MNIST): Parameters ---------- - root : str, default '~/.mxnet/datasets/fashion-mnist' + root : str, default $MXNET_DATA_DIR/datasets/fashion-mnist' Path to temp folder for storing data. train : bool, default True Whether to load the training or testing set. @@ -112,7 +112,7 @@ class FashionMNIST(MNIST): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'), train=True, transform=None): self._train = train self._train_data = ('train-images-idx3-ubyte.gz', @@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset): Parameters ---------- - root : str, default '~/.mxnet/datasets/cifar10' + root : str, default $MXNET_DATA_DIR/datasets/cifar10 Path to temp folder for storing data. train : bool, default True Whether to load the training or testing set. @@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'), train=True, transform=None): self._train = train self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891') @@ -197,7 +197,7 @@ class CIFAR100(CIFAR10): Parameters ---------- - root : str, default '~/.mxnet/datasets/cifar100' + root : str, default $MXNET_DATA_DIR/datasets/cifar100 Path to temp folder for storing data. fine_label : bool, default False Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels. @@ -210,7 +210,7 @@ class CIFAR100(CIFAR10): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar100'), + def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'), fine_label=False, train=True, transform=None): self._train = train self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b') diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py index fb66a713685c..c5f5434fe351 100644 --- a/python/mxnet/gluon/model_zoo/model_store.py +++ b/python/mxnet/gluon/model_zoo/model_store.py @@ -21,8 +21,10 @@ __all__ = ['get_model_file', 'purge'] import os import zipfile +import logging from ..utils import download, check_sha1 +from ... import base _model_sha1 = {name: checksum for checksum, name in [ ('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'), @@ -68,7 +70,7 @@ def short_hash(name): raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) return _model_sha1[name][:8] -def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')): +def get_model_file(name, root=os.path.join(base.data_dir(), 'models')): r"""Return location for the pretrained on local file system. This function will download from online model zoo when model cannot be found or has mismatch. @@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')): ---------- name : str Name of the model. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. Returns @@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')): if check_sha1(file_path, sha1_hash): return file_path else: - print('Mismatch in the content of model file detected. Downloading again.') + logging.warn('Mismatch in the content of model file detected. Downloading again.') else: - print('Model file is not found. Downloading.') + logging.info('Model file not found. Downloading to %s.', file_path) - if not os.path.exists(root): - os.makedirs(root) + os.makedirs(root, exist_ok=True) zip_file_path = os.path.join(root, file_name+'.zip') repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) @@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')): else: raise ValueError('Downloaded file has different hash. Please try again.') -def purge(root=os.path.join('~', '.mxnet', 'models')): +def purge(root=os.path.join(base.data_dir(), 'models')): r"""Purge all pretrained model files in local file store. Parameters ---------- - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ root = os.path.expanduser(root) diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py b/python/mxnet/gluon/model_zoo/vision/__init__.py index a6e5dc137d48..798125eef955 100644 --- a/python/mxnet/gluon/model_zoo/vision/__init__.py +++ b/python/mxnet/gluon/model_zoo/vision/__init__.py @@ -101,7 +101,7 @@ def get_model(name, **kwargs): Number of classes for the output layer. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. Returns diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py b/python/mxnet/gluon/model_zoo/vision/alexnet.py index fdb006258c2a..35d5c239490b 100644 --- a/python/mxnet/gluon/model_zoo/vision/alexnet.py +++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py @@ -25,6 +25,7 @@ from ....context import cpu from ...block import HybridBlock from ... import nn +from .... import base # Net class AlexNet(HybridBlock): @@ -68,7 +69,7 @@ def hybrid_forward(self, F, x): # Constructor def alexnet(pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""AlexNet model from the `"One weird trick..." `_ paper. Parameters @@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ net = AlexNet(**kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py b/python/mxnet/gluon/model_zoo/vision/densenet.py index b03f5ce8d52a..ae0018e5fd2d 100644 --- a/python/mxnet/gluon/model_zoo/vision/densenet.py +++ b/python/mxnet/gluon/model_zoo/vision/densenet.py @@ -26,6 +26,7 @@ from ...block import HybridBlock from ... import nn from ...contrib.nn import HybridConcurrent, Identity +from .... import base # Helpers def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index): @@ -122,7 +123,7 @@ def hybrid_forward(self, F, x): # Constructor def get_densenet(num_layers, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""Densenet-BC model from the `"Densely Connected Convolutional Networks" `_ paper. @@ -134,7 +135,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ num_init_features, growth_rate, block_config = densenet_spec[num_layers] @@ -154,7 +155,7 @@ def densenet121(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_densenet(121, **kwargs) @@ -169,7 +170,7 @@ def densenet161(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_densenet(161, **kwargs) @@ -184,7 +185,7 @@ def densenet169(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_densenet(169, **kwargs) @@ -199,7 +200,7 @@ def densenet201(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_densenet(201, **kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py b/python/mxnet/gluon/model_zoo/vision/inception.py index 7c54691f1b59..0693bb559deb 100644 --- a/python/mxnet/gluon/model_zoo/vision/inception.py +++ b/python/mxnet/gluon/model_zoo/vision/inception.py @@ -26,6 +26,7 @@ from ...block import HybridBlock from ... import nn from ...contrib.nn import HybridConcurrent +from .... import base # Helpers def _make_basic_conv(**kwargs): @@ -199,7 +200,7 @@ def hybrid_forward(self, F, x): # Constructor def inception_v3(pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""Inception v3 model from `"Rethinking the Inception Architecture for Computer Vision" `_ paper. @@ -210,7 +211,7 @@ def inception_v3(pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ net = Inception3(**kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py b/python/mxnet/gluon/model_zoo/vision/mobilenet.py index 1a2c9b946190..6e89365b0cb6 100644 --- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py +++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py @@ -30,6 +30,7 @@ from ... import nn from ....context import cpu from ...block import HybridBlock +from .... import base # Helpers @@ -188,7 +189,7 @@ def hybrid_forward(self, F, x): # Constructor def get_mobilenet(multiplier, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""MobileNet model from the `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" `_ paper. @@ -203,7 +204,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ net = MobileNet(multiplier, **kwargs) @@ -219,7 +220,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(), def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""MobileNetV2 model from the `"Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation" @@ -235,7 +236,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ net = MobileNetV2(multiplier, **kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py index da279b89583e..4aa101370dc0 100644 --- a/python/mxnet/gluon/model_zoo/vision/resnet.py +++ b/python/mxnet/gluon/model_zoo/vision/resnet.py @@ -32,6 +32,7 @@ from ....context import cpu from ...block import HybridBlock from ... import nn +from .... import base # Helpers def _conv3x3(channels, stride, in_channels): @@ -356,7 +357,7 @@ def hybrid_forward(self, F, x): # Constructor def get_resnet(version, num_layers, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition" `_ paper. ResNet V2 model from `"Identity Mappings in Deep Residual Networks" @@ -372,7 +373,7 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ assert num_layers in resnet_spec, \ @@ -400,7 +401,7 @@ def resnet18_v1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(1, 18, **kwargs) @@ -415,7 +416,7 @@ def resnet34_v1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(1, 34, **kwargs) @@ -430,7 +431,7 @@ def resnet50_v1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(1, 50, **kwargs) @@ -445,7 +446,7 @@ def resnet101_v1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(1, 101, **kwargs) @@ -460,7 +461,7 @@ def resnet152_v1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(1, 152, **kwargs) @@ -475,7 +476,7 @@ def resnet18_v2(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(2, 18, **kwargs) @@ -490,7 +491,7 @@ def resnet34_v2(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(2, 34, **kwargs) @@ -505,7 +506,7 @@ def resnet50_v2(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(2, 50, **kwargs) @@ -520,7 +521,7 @@ def resnet101_v2(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(2, 101, **kwargs) @@ -535,7 +536,7 @@ def resnet152_v2(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_resnet(2, 152, **kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py b/python/mxnet/gluon/model_zoo/vision/squeezenet.py index aaff4c36dfa0..182ac5b64d16 100644 --- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py +++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py @@ -26,6 +26,7 @@ from ...block import HybridBlock from ... import nn from ...contrib.nn import HybridConcurrent +from .... import base # Helpers def _make_fire(squeeze_channels, expand1x1_channels, expand3x3_channels): @@ -110,7 +111,7 @@ def hybrid_forward(self, F, x): # Constructor def get_squeezenet(version, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""SqueezeNet model from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. SqueezeNet 1.1 model from the `official SqueezeNet repo @@ -126,7 +127,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ net = SqueezeNet(version, **kwargs) @@ -145,7 +146,7 @@ def squeezenet1_0(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_squeezenet('1.0', **kwargs) @@ -162,7 +163,7 @@ def squeezenet1_1(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_squeezenet('1.1', **kwargs) diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py b/python/mxnet/gluon/model_zoo/vision/vgg.py index a3b1685b4130..afa9f78506b5 100644 --- a/python/mxnet/gluon/model_zoo/vision/vgg.py +++ b/python/mxnet/gluon/model_zoo/vision/vgg.py @@ -30,6 +30,7 @@ from ....initializer import Xavier from ...block import HybridBlock from ... import nn +from .... import base class VGG(HybridBlock): @@ -94,7 +95,7 @@ def hybrid_forward(self, F, x): # Constructors def get_vgg(num_layers, pretrained=False, ctx=cpu(), - root=os.path.join('~', '.mxnet', 'models'), **kwargs): + root=os.path.join(base.data_dir(), 'models'), **kwargs): r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition" `_ paper. @@ -106,7 +107,7 @@ def get_vgg(num_layers, pretrained=False, ctx=cpu(), Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default $MXNET_DATA_DIR/models Location for keeping the model parameters. """ layers, filters = vgg_spec[num_layers] @@ -128,7 +129,7 @@ def vgg11(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_vgg(11, **kwargs) @@ -143,7 +144,7 @@ def vgg13(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_vgg(13, **kwargs) @@ -158,7 +159,7 @@ def vgg16(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_vgg(16, **kwargs) @@ -173,7 +174,7 @@ def vgg19(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ return get_vgg(19, **kwargs) @@ -189,7 +190,7 @@ def vgg11_bn(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ kwargs['batch_norm'] = True @@ -206,7 +207,7 @@ def vgg13_bn(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ kwargs['batch_norm'] = True @@ -223,7 +224,7 @@ def vgg16_bn(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ kwargs['batch_norm'] = True @@ -240,7 +241,7 @@ def vgg19_bn(**kwargs): Whether to load the pretrained weights for model. ctx : Context, default CPU The context in which to load the pretrained weights. - root : str, default '~/.mxnet/models' + root : str, default '$MXNET_DATA_DIR/models' Location for keeping the model parameters. """ kwargs['batch_norm'] = True diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index e963d158446d..5643b158f9f3 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -484,7 +484,7 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan= return index, rel = find_max_violation(a, b, rtol, atol) - np.set_printoptions(threshold=4, suppress=True) + np.set_printoptions(threshold=np.nan, suppress=True) msg = npt.build_err_msg([a, b], err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. " " Location of maximum error:%s, a=%f, b=%f" diff --git a/tests/python/unittest/test_base.py b/tests/python/unittest/test_base.py new file mode 100644 index 000000000000..e1b4f315c456 --- /dev/null +++ b/tests/python/unittest/test_base.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet.base import data_dir +from nose.tools import * +import os +import unittest +import logging +import os.path as op + +class MXNetDataDirTest(unittest.TestCase): + def setUp(self): + self.mxnet_data_dir = os.environ.get('MXNET_DATA_DIR') + if 'MXNET_DATA_DIR' in os.environ: + del os.environ['MXNET_DATA_DIR'] + + def tearDown(self): + if self.mxnet_data_dir: + os.environ['MXNET_DATA_DIR'] = self.mxnet_data_dir + else: + if 'MXNET_DATA_DIR' in os.environ: + del os.environ['MXNET_DATA_DIR'] + + def test_data_dir(self,): + self.assertEqual(data_dir(), op.join(op.expanduser('~'), '.local', 'share', 'mxnet')) + os.environ['MXNET_DATA_DIR'] = '/tmp/mxnet_data' + self.assertEqual(data_dir(), '/tmp/mxnet_data') + os.environ['XDG_DATA_HOME'] = '/blah/data' + self.assertEqual(data_dir(), '/tmp/mxnet_data') + del os.environ['MXNET_DATA_DIR'] + self.assertEqual(data_dir(), '/blah/data/mxnet') + +