
set MXNET_DATA_DIR as base for downloaded models through base.data_dir()
larroy committed Jul 10, 2018
1 parent 0bec94b commit addd518
Showing 16 changed files with 157 additions and 69 deletions.
5 changes: 5 additions & 0 deletions docs/faq/env_var.md
@@ -126,6 +126,11 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
- Values: String ```(default='https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/')```
- The repository url to be used for Gluon datasets and pre-trained models.

* MXNET_DATA_DIR
- Values: String (default depends on the platform)
- The data directory used for filesystem storage, for example when downloading Gluon models.
- The default on *nix is ```$XDG_DATA_HOME/mxnet``` if ```XDG_DATA_HOME``` is set, otherwise ```~/.local/share/mxnet```; on Windows it is ```%APPDATA%\mxnet```.
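A minimal sketch of the intended interaction with the Python API (the ```/tmp/mxnet-data``` path is illustrative; ```base.data_dir()``` is introduced in this commit):

```python
import os

# Set MXNET_DATA_DIR *before* importing mxnet: the gluon APIs capture
# data_dir() in their default arguments at import time.
os.environ['MXNET_DATA_DIR'] = '/tmp/mxnet-data'

from mxnet import base

print(base.data_dir())  # -> /tmp/mxnet-data; models land in /tmp/mxnet-data/models
```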

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
27 changes: 27 additions & 0 deletions python/mxnet/base.py
@@ -27,6 +27,7 @@
import warnings
import inspect
import numpy as np
import platform
from . import libinfo
warnings.filterwarnings('default', category=DeprecationWarning)

@@ -47,6 +48,32 @@
integer_types = (int, long, np.int32, np.int64)
py_str = lambda x: x


def data_dir_default():
"""
:return: default data directory depending on the platform and environment variables
"""
system = platform.system()
    if system == 'Windows':
        # Join 'mxnet' so the result matches the documented %APPDATA%\mxnet default;
        # fall back to the home directory if APPDATA is unset.
        return os.path.join(os.environ.get('APPDATA', os.path.expanduser("~")), 'mxnet')
elif system == 'Linux' or system == 'Darwin':
        if 'XDG_DATA_HOME' in os.environ:
return os.path.join(os.environ['XDG_DATA_HOME'], 'mxnet')
else:
return os.path.join(os.path.expanduser("~"), '.local', 'share', 'mxnet')
else:
return os.path.join(os.path.expanduser("~"), '.mxnet')


def data_dir():
"""
:return: data directory in the filesystem for storage, for example when downloading models
"""
return os.getenv('MXNET_DATA_DIR', data_dir_default())


class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
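The lookup order is therefore MXNET_DATA_DIR, then XDG_DATA_HOME on *nix, then a per-platform fallback. An illustrative check (paths hypothetical, Linux assumed):

```python
import os
from mxnet import base

# With MXNET_DATA_DIR unset, *nix resolves via XDG_DATA_HOME.
os.environ.pop('MXNET_DATA_DIR', None)
os.environ['XDG_DATA_HOME'] = '/home/user/.local/share'
print(base.data_dir())  # -> /home/user/.local/share/mxnet

# When set, MXNET_DATA_DIR always wins.
os.environ['MXNET_DATA_DIR'] = '/data/mxnet'
print(base.data_dir())  # -> /data/mxnet

# Caveat: defaults such as root=os.path.join(base.data_dir(), 'models') are
# evaluated once at import time, so set the variable before importing mxnet.
```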
9 changes: 5 additions & 4 deletions python/mxnet/contrib/text/embedding.py
@@ -34,6 +34,7 @@
from . import vocab
from ... import ndarray as nd
from ... import registry
from ... import base


def register(embedding_cls):
@@ -496,7 +497,7 @@ class GloVe(_TokenEmbedding):
----------
pretrained_file_name : str, default 'glove.840B.300d.txt'
The name of the pre-trained token embedding file.
-    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default $MXNET_DATA_DIR/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
@@ -541,7 +542,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return archive

def __init__(self, pretrained_file_name='glove.840B.300d.txt',
-                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+                 embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
GloVe._check_pretrained_file_names(pretrained_file_name)

@@ -600,7 +601,7 @@ class FastText(_TokenEmbedding):
----------
pretrained_file_name : str, default 'wiki.en.vec'
The name of the pre-trained token embedding file.
-    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default $MXNET_DATA_DIR/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
@@ -642,7 +643,7 @@ def _get_download_file_name(cls, pretrained_file_name):
return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'

def __init__(self, pretrained_file_name='wiki.simple.vec',
-                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
+                 embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
FastText._check_pretrained_file_names(pretrained_file_name)

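As a usage sketch, the embeddings now cache under ```base.data_dir()```; ```wiki.simple.vec``` is the pretrained file already used as the default above (downloaded on first use, so network access is assumed):

```python
from mxnet.contrib import text

# Cached under $MXNET_DATA_DIR/embeddings by default;
# pass embedding_root=... to override per instance.
embedding = text.embedding.FastText(pretrained_file_name='wiki.simple.vec')
print(embedding.vec_len)  # dimensionality of each embedding vector
```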
11 changes: 5 additions & 6 deletions python/mxnet/gluon/contrib/data/text.py
@@ -30,8 +30,7 @@
from ...data import dataset
from ...utils import download, check_sha1, _get_repo_file_url
from ....contrib import text
-from .... import nd
-
+from .... import nd, base

class _LanguageModelDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method
def __init__(self, root, namespace, vocabulary):
@@ -116,7 +115,7 @@ class WikiText2(_WikiText):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/wikitext-2'
+    root : str, default $MXNET_DATA_DIR/datasets/wikitext-2
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
@@ -127,7 +126,7 @@
The sequence length of each sample, regardless of the sentence boundary.
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-2'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
@@ -154,7 +153,7 @@ class WikiText103(_WikiText):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/wikitext-103'
+    root : str, default $MXNET_DATA_DIR/datasets/wikitext-103
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
@@ -164,7 +163,7 @@
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'wikitext-103'),
segment='train', vocab=None, seq_len=35):
self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
self._data_file = {'train': ('wiki.train.tokens',
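A brief sketch of the datasets after this change (first use downloads the archive):

```python
from mxnet.gluon.contrib.data.text import WikiText2

# Stored under $MXNET_DATA_DIR/datasets/wikitext-2 unless root= is given.
train = WikiText2(segment='train', seq_len=35)
print(len(train))  # number of seq_len-sized samples
```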
18 changes: 9 additions & 9 deletions python/mxnet/gluon/data/vision/datasets.py
@@ -30,7 +30,7 @@

from .. import dataset
from ...utils import download, check_sha1, _get_repo_file_url
-from .... import nd, image, recordio
+from .... import nd, image, recordio, base


class MNIST(dataset._DownloadedDataset):
@@ -40,7 +40,7 @@ class MNIST(dataset._DownloadedDataset):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/mnist'
+    root : str, default $MXNET_DATA_DIR/datasets/mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -51,7 +51,7 @@ class MNIST(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
@@ -101,7 +101,7 @@ class FashionMNIST(MNIST):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/fashion-mnist'
+    root : str, default $MXNET_DATA_DIR/datasets/fashion-mnist
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -112,7 +112,7 @@ class FashionMNIST(MNIST):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'),
train=True, transform=None):
self._train = train
self._train_data = ('train-images-idx3-ubyte.gz',
@@ -134,7 +134,7 @@ class CIFAR10(dataset._DownloadedDataset):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/cifar10'
+    root : str, default $MXNET_DATA_DIR/datasets/cifar10
Path to temp folder for storing data.
train : bool, default True
Whether to load the training or testing set.
@@ -145,7 +145,7 @@ class CIFAR10(dataset._DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'),
train=True, transform=None):
self._train = train
self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
@@ -197,7 +197,7 @@ class CIFAR100(CIFAR10):
Parameters
----------
-    root : str, default '~/.mxnet/datasets/cifar100'
+    root : str, default $MXNET_DATA_DIR/datasets/cifar100
Path to temp folder for storing data.
fine_label : bool, default False
Whether to load the fine-grained (100 classes) or coarse-grained (20 super-classes) labels.
@@ -210,7 +210,7 @@ class CIFAR100(CIFAR10):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar100'),
+    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
fine_label=False, train=True, transform=None):
self._train = train
self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
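For instance, a vision-dataset sketch under the new scheme:

```python
from mxnet.gluon.data.vision import MNIST

# Cached under $MXNET_DATA_DIR/datasets/mnist by default.
mnist_train = MNIST(train=True)
data, label = mnist_train[0]
print(data.shape, label)  # (28, 28, 1) and the digit's class index
```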
17 changes: 9 additions & 8 deletions python/mxnet/gluon/model_zoo/model_store.py
@@ -21,8 +21,10 @@
__all__ = ['get_model_file', 'purge']
import os
import zipfile
import logging

from ..utils import download, check_sha1
from ... import base

_model_sha1 = {name: checksum for checksum, name in [
('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
@@ -68,7 +70,7 @@ def short_hash(name):
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]

-def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
+def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
@@ -78,7 +80,7 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
----------
name : str
Name of the model.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
Returns
@@ -95,12 +97,11 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
if check_sha1(file_path, sha1_hash):
return file_path
else:
-        print('Mismatch in the content of model file detected. Downloading again.')
+        logging.warning('Mismatch in the content of model file detected. Downloading again.')
else:
-        print('Model file is not found. Downloading.')
+        logging.info('Model file not found. Downloading to %s.', file_path)

-    if not os.path.exists(root):
-        os.makedirs(root)
+    os.makedirs(root, exist_ok=True)

zip_file_path = os.path.join(root, file_name+'.zip')
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
@@ -118,12 +119,12 @@ def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
else:
raise ValueError('Downloaded file has different hash. Please try again.')

-def purge(root=os.path.join('~', '.mxnet', 'models')):
+def purge(root=os.path.join(base.data_dir(), 'models')):
r"""Purge all pretrained model files in local file store.
Parameters
----------
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
root = os.path.expanduser(root)
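A sketch of the resulting model-store behavior; 'alexnet' appears in the ```_model_sha1``` table above (the first call downloads the parameter file):

```python
import os
from mxnet.gluon.model_zoo.model_store import get_model_file, purge

# Fetched and cached under $MXNET_DATA_DIR/models.
path = get_model_file('alexnet')
print(os.path.exists(path))  # True once downloaded

# Careful: purge() deletes every cached .params file in the default root.
purge()
```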
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -101,7 +101,7 @@ def get_model(name, **kwargs):
Number of classes for the output layer.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
Returns
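For example, the root can still be overridden per call; this sketch just spells out the new default explicitly:

```python
import os
from mxnet import base
from mxnet.gluon.model_zoo import vision

# Equivalent to the default root after this commit.
net = vision.get_model('alexnet', pretrained=True,
                       root=os.path.join(base.data_dir(), 'models'))
print(type(net).__name__)  # AlexNet
```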
5 changes: 3 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -25,6 +25,7 @@
from ....context import cpu
from ...block import HybridBlock
from ... import nn
from .... import base

# Net
class AlexNet(HybridBlock):
@@ -68,7 +69,7 @@ def hybrid_forward(self, F, x):

# Constructor
def alexnet(pretrained=False, ctx=cpu(),
-            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+            root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""AlexNet model from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Parameters
@@ -77,7 +78,7 @@ def alexnet(pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
net = AlexNet(**kwargs)
13 changes: 7 additions & 6 deletions python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -26,6 +26,7 @@
from ...block import HybridBlock
from ... import nn
from ...contrib.nn import HybridConcurrent, Identity
from .... import base

# Helpers
def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
@@ -122,7 +123,7 @@ def hybrid_forward(self, F, x):

# Constructor
def get_densenet(num_layers, pretrained=False, ctx=cpu(),
-                 root=os.path.join('~', '.mxnet', 'models'), **kwargs):
+                 root=os.path.join(base.data_dir(), 'models'), **kwargs):
r"""Densenet-BC model from the
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
@@ -134,7 +135,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
num_init_features, growth_rate, block_config = densenet_spec[num_layers]
@@ -154,7 +155,7 @@ def densenet121(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
return get_densenet(121, **kwargs)
@@ -169,7 +170,7 @@ def densenet161(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
return get_densenet(161, **kwargs)
@@ -184,7 +185,7 @@ def densenet169(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
return get_densenet(169, **kwargs)
@@ -199,7 +200,7 @@ def densenet201(**kwargs):
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
-    root : str, default '~/.mxnet/models'
+    root : str, default $MXNET_DATA_DIR/models
Location for keeping the model parameters.
"""
return get_densenet(201, **kwargs)
