diff --git a/example/notebooks/alexnet.ipynb b/example/notebooks/alexnet.ipynb
index b7bb6bf266c2..c030d873cd08 100644
--- a/example/notebooks/alexnet.ipynb
+++ b/example/notebooks/alexnet.ipynb
@@ -29,7 +29,6 @@
    },
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
     "import mxnet as mx"
    ]
   },
@@ -402,7 +401,7 @@
     }
    ],
    "source": [
-    "mx.visualization.plot_network(\"AlexNet\", softmax)"
+    "mx.viz.plot_network(\"AlexNet\", softmax)"
    ]
   },
   {
@@ -425,28 +424,32 @@
     "# We set batch size for to 256\n",
     "batch_size = 256\n",
     "# We need to set correct path to image record file\n",
-    "# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n",
+    "# For ```mean_image```. If it doesn't exist, the iterator will generate one\n",
+    "# On an HDD, a single thread is able to process 800 images / sec\n",
     "# the input shape is in format (channel, height, width)\n",
     "# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n",
     "# rand_mirror option make source image randomly mirrored\n",
     "# We use 2 threads to processing our data\n",
     "train_dataiter = mx.io.ImageRecordIter(\n",
+    "        shuffle=True,\n",
     "        path_imgrec=\"./Data/ImageNet/train.rec\",\n",
     "        mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
     "        rand_crop=True,\n",
     "        rand_mirror=True,\n",
-    "        input_shape=(3, 224, 224),\n",
+    "        data_shape=(3, 224, 224),\n",
     "        batch_size=batch_size,\n",
-    "        nthread=2)\n",
+    "        prefetch_buffer=4,\n",
+    "        preprocess_threads=2)\n",
     "# similarly, we can declare our validation iterator\n",
     "val_dataiter = mx.io.ImageRecordIter(\n",
     "        path_imgrec=\"./Data/ImageNet/val.rec\",\n",
     "        mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
     "        rand_crop=False,\n",
     "        rand_mirror=False,\n",
-    "        input_shape=(3, 224, 224),\n",
+    "        data_shape=(3, 224, 224),\n",
     "        batch_size=batch_size,\n",
-    "        nthread=2)"
+    "        prefetch_buffer=4,\n",
+    "        preprocess_threads=2)"
    ]
   },
   {
@@ -531,7 +534,7 @@
     "# When we use data iterator, we don't need to set y because label comes from data iterator directly\n",
     "# In this case, eval_data is also a data iterator\n",
     "# We will use accuracy to measure our model's performace\n",
-    "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n",
+    "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc')\n",
     "# You need to wait for a while to get the result"
    ]
   },
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 89dd2c09da79..e9630b678ee0 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -12,6 +12,7 @@
 from .base import MXNetError
 from . import base
 from . import ndarray
+from . import name
 from . import symbol
 from . import kvstore as kv
 from . import io
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index fff45dc7b895..485d292aa203 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -3,23 +3,36 @@
 from __future__ import absolute_import
 
 class Context(object):
-    """Context representing device and device id in mxnet"""
+    """Constructing a context.
+
+    Parameters
+    ----------
+    device_type : {'cpu', 'gpu'} or Context
+        String representing the device type.
+
+    device_id : int (default=0)
+        The device id of the device, needed for GPU.
+
+    Note
+    ----
+    Context can also be used as a way to change the default context.
+
+    Examples
+    --------
+    Switch default context example:
+    >>> # array on cpu
+    >>> cpu_array = mx.ndarray.ones((2, 3))
+    >>> # switch default context to GPU(2)
+    >>> with mx.Context(mx.gpu(2)):
+    >>>     gpu_array = mx.ndarray.ones((2, 3))
+    >>> gpu_array.context
+    Context(device_type=gpu, device_id=2)
+    """
     # static class variable
     default_ctx = None
     devtype2str = {1: 'cpu', 2: 'gpu'}
     devstr2type = {'cpu': 1, 'gpu': 2}
-
     def __init__(self, device_type, device_id=0):
-        """Constructing a context.
-
-        Parameters
-        ----------
-        device_type : str (can be 'cpu' or 'gpu')
-            a string representing the device type
-
-        device_id : int (default=0)
-            the device id of the device, needed for GPU
-        """
         if isinstance(device_type, Context):
             self.device_typeid = device_type.device_typeid
             self.device_id = device_type.device_id
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index bd64413ca295..fa13594926f7 100644
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -1,4 +1,7 @@
-# pylint: skip-file
+# coding: utf-8
+"""Initialization helper for mxnet"""
+from __future__ import absolute_import
+
 import numpy as np
 from .base import string_types
 from .ndarray import NDArray
@@ -36,17 +39,17 @@ def __call__(self, name, arr):
             self._init_zero(name, arr)
         else:
             self._init_default(name, arr)
-
-    def _init_zero(self, name, arr):
+    # pylint: disable=no-self-use, missing-docstring
+    def _init_zero(self, _, arr):
         arr[:] = 0.0
 
-    def _init_bias(self, name, arr):
+    def _init_bias(self, _, arr):
         arr[:] = 0.0
 
-    def _init_gamma(self, name, arr):
+    def _init_gamma(self, _, arr):
         arr[:] = 1.0
 
-    def _init_beta(self, name, arr):
+    def _init_beta(self, _, arr):
         arr[:] = 0.0
 
     def _init_weight(self, name, arr):
@@ -55,7 +58,7 @@ def _init_weight(self, name, arr):
 
     def _init_default(self, name, _):
         raise ValueError('Unknown initialization pattern for %s' % name)
-
+    # pylint: enable=no-self-use, missing-docstring
 
 class Uniform(Initializer):
     """Initialize the weight with uniform [-scale, scale]
@@ -68,8 +71,8 @@ class Uniform(Initializer):
     def __init__(self, scale=0.07):
         self.scale = scale
 
-    def _init_weight(self, name, arr):
-        random.uniform(-scale, scale, out=arr)
+    def _init_weight(self, _, arr):
+        random.uniform(-self.scale, self.scale, out=arr)
 
 
 class Normal(Initializer):
@@ -81,10 +84,10 @@ class Normal(Initializer):
         Standard deviation for gaussian distribution.
     """
     def __init__(self, sigma=0.01):
-        super().__init__(sigma = sigma)
+        self.sigma = sigma
 
-    def _init_weight(self, name, arr):
-        random.normal(0, sigma, out=arr)
+    def _init_weight(self, _, arr):
+        random.normal(0, self.sigma, out=arr)
 
 
 class Xavier(Initializer):
@@ -95,6 +98,6 @@ def _init_weight(self, _, arr):
         # [in, out] for fullc
         shape = arr.shape
         fan_in, fan_out = shape[1], shape[0]
-        s = np.sqrt(6. / (fan_in + fan_out))
-        random.uniform(-s, s, out=arr)
+        scale = np.sqrt(6. / (fan_in + fan_out))
+        random.uniform(-scale, scale, out=arr)
 
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index cb55df71aa3f..e4e6905aba3a 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -1,5 +1,4 @@
 # coding: utf-8
-
 """NDArray interface of mxnet"""
 from __future__ import absolute_import
 
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 8f4822d85a69..ac5691d37909 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -2,6 +2,7 @@
 # pylint: disable=invalid-name, global-statement
 """ KVStore in mxnet """
 from __future__ import absolute_import
+
 import ctypes
 from .ndarray import NDArray
 from .base import _LIB
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 11e440988f51..3466bf9362d9 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -1,6 +1,8 @@
 # pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals
 # pylint: disable=too-many-branches, too-many-statements, unused-argument
 """MXNet model module"""
+from __future__ import absolute_import
+
 import numpy as np
 import time
 import logging
@@ -201,11 +203,18 @@ def _train_multi_device(symbol, ctx, input_shape,
             aux_params[name].copyto(w)
     # key value store
     kv = kvstore.create() if num_device != 1 else None
+    opt_state_blocks = []
     # If there are multiple devices, initialize the weights.
     for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
-        arg, grad = pair
-        if kv and grad[0] is not None:
-            kv.init(index, arg[0])
+        arg_list, grad_list = pair
+        if kv and grad_list[0] is not None:
+            kv.init(index, arg_list[0])
+            # attach state directly to weight
+            opt_list = [optimizer.create_state(index, w) for w in arg_list]
+            opt_state_blocks.append(opt_list)
+        else:
+            opt_state_blocks.append(None)
+
     # Input and output data structure
     data_index, label_index = _check_arguments(symbol)
     merged_shape = list(train_execs[0].outputs[0].shape)
@@ -244,9 +253,10 @@
                     kv.push(index, grad_list)
                     # pull back the sum, to the same locations.
                     kv.pull(index, grad_list)
-                # optimize
-                for w, g in zip(arg_list, grad_list):
-                    optimizer.update(index, w, g)
+                opt_list = opt_state_blocks[index]
+                # optimize
+                for w, g, state in zip(arg_list, grad_list, opt_list):
+                    optimizer.update(index, w, g, state)
             # evaluate at end, so out_cpu_array can lazy copy
             eval_metric.update(out_cpu_array, label)
 
diff --git a/python/mxnet/name.py b/python/mxnet/name.py
new file mode 100644
index 000000000000..b0c8bff52a8a
--- /dev/null
+++ b/python/mxnet/name.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+"""Automatic naming support for symbolic API."""
+from __future__ import absolute_import
+
+class NameManager(object):
+    """NameManager to do automatic naming.
+
+    Users can also inherit from this object to change naming behavior.
+    """
+    current = None
+
+    def __init__(self):
+        self._counter = {}
+        self._old_manager = None
+
+    def get(self, name, hint):
+        """Get the canonical name for a symbol.
+
+        This is the default implementation.
+        When the user specifies a name,
+        the user-specified name will be used.
+
+        When the user does not, we automatically generate a
+        name based on the hint string.
+
+        Parameters
+        ----------
+        name : str or None
+            The name the user specified.
+
+        hint : str
+            A hint string, which can be used to generate a name.
+
+        Returns
+        -------
+        full_name : str
+            A canonical name for the user.
+ """ + if name: + return name + if hint not in self._counter: + self._counter[hint] = 0 + name = '%s%d' % (hint, self._counter[hint]) + self._counter[hint] += 1 + return name + + def __enter__(self): + self._old_manager = NameManager.current + NameManager.current = self + return self + + def __exit__(self, ptype, value, trace): + assert self._old_manager + NameManager.current = self._old_manager + + +class Prefix(NameManager): + """A name manager that always attach a prefix to all names. + + Examples + -------- + >>> import mxnet as mx + >>> data = mx.symbol.Variable('data') + >>> with mx.name.Prefix('mynet_'): + net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1') + >>> net.list_arguments() + ['data', 'mynet_fc1_weight', 'mynet_fc1_bias'] + """ + def __init__(self, prefix): + super(Prefix, self).__init__() + self._prefix = prefix + + def get(self, name, hint): + name = super(Prefix, self).get(name, hint) + return self._prefix + name + +# initialize the default name manager +NameManager.current = NameManager() diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 8cc3d1b4f241..d1f0ae4ef246 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -1,4 +1,4 @@ -# pylint: disable=fixme, invalid-name +# pylint: disable=fixme, invalid-name, unused-argument """Common Optimization algorithms with regularizations.""" from .ndarray import NDArray, zeros @@ -31,7 +31,6 @@ class SGD(Optimizer): rescale_grad : float, optional rescaling factor of gradient. - """ def __init__(self, learning_rate=0.01, momentum=0.0, wd=0.0001, rescale_grad=1): @@ -41,7 +40,21 @@ def __init__(self, learning_rate=0.01, momentum=0.0, self.rescale_grad = rescale_grad self.momentums = {} - def update(self, index, weight, grad): + def create_state(self, index, weight): + """Create additional optimizer state such as momentum. + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return zeros(weight.shape, weight.context) + + def update(self, index, weight, grad, state): """Update the parameters. Parameters @@ -55,17 +68,20 @@ def update(self, index, weight, grad): grad : NDArray grad ndarray + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. 
""" # TODO(bing) implement wd_bias, wd_gamma, wd_beta assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) - - if index not in self.momentums: - self.momentums[index] = zeros(grad.shape, grad.context) - mom = self.momentums[index] - mom[:] *= self.momentum - mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) - weight[:] += mom + if state: + mom = state + mom[:] *= self.momentum + mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) + weight[:] += mom + else: + assert self.momentum == 0.0 + weight[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight) def create(name, rescale_grad=1, **kwargs): diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index a7d53f28fc1a..44318c66200c 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -13,6 +13,7 @@ from .base import c_array, c_str, mx_uint, py_str, string_types from .base import NDArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call, ctypes2docstring +from .name import NameManager from .context import Context from .ndarray import NDArray, zeros from .executor import Executor @@ -128,6 +129,7 @@ def _compose(self, *args, **kwargs): the resulting symbol """ name = kwargs.pop('name', None) + if name: name = c_str(name) if len(args) != 0 and len(kwargs) != 0: @@ -752,6 +754,8 @@ def creator(*args, **kwargs): ' instead of keyword arguments.') s = Symbol(sym_handle) + hint = func_name.lower() + name = NameManager.current.get(name, hint) s._compose(*args, name=name, **symbol_kwargs) return s diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 86fc53c37311..3992a241b69f 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -2,6 +2,8 @@ # pylint: disable=invalid-name, protected-access, too-many-locals, fixme # pylint: disable=unused-argument, too-many-branches, too-many-statements """Visualization module""" +from __future__ import absolute_import + from .symbol import Symbol import json import re diff --git a/tests/python/common/get_data.py b/tests/python/common/get_data.py index 270132e448b8..65e8ac59ad6f 100644 --- a/tests/python/common/get_data.py +++ b/tests/python/common/get_data.py @@ -14,18 +14,14 @@ def GetMNIST_pkl(): def GetMNIST_ubyte(): if not os.path.isdir("data/"): os.system("mkdir data/") - if not os.path.exists('data/train-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/train-images-idx3-ubyte.gz") - if not os.path.exists('data/train-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/train-labels-idx1-ubyte.gz") - if not os.path.exists('data/t10k-images-idx3-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -P data/") - os.system("gunzip data/t10k-images-idx3-ubyte.gz") - if not os.path.exists('data/t10k-labels-idx1-ubyte'): - os.system("wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -P data/") - os.system("gunzip data/t10k-labels-idx1-ubyte.gz") + if (not os.path.exists('data/train-images-idx3-ubyte')) or \ + (not os.path.exists('data/train-labels-idx1-ubyte')) or \ + (not os.path.exists('data/t10k-images-idx3-ubyte')) or \ + (not os.path.exists('data/t10k-labels-idx1-ubyte')): + os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P data/") + os.chdir("./data") + os.system("unzip -u mnist.zip") + os.chdir("..") # download cifar def GetCifar10(): @@ 
@@ -34,5 +30,5 @@ def GetCifar10():
     if not os.path.exists('data/cifar10.zip'):
         os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P data/")
     os.chdir("./data")
-    os.system("unzip cifar10.zip")
+    os.system("unzip -u cifar10.zip")
     os.chdir("..")
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index 40304187a5fa..b0849a3e81d9 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -18,7 +18,7 @@
 num_round = 4
 prefix = './mlp'
 
-model = mx.model.FeedForward(softmax, mx.cpu(),
+model = mx.model.FeedForward(softmax, [mx.cpu()] * 2,
                              num_round=num_round,
                              learning_rate=0.01, wd=0.0004,
                              momentum=0.9)
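
Usage note (not part of the patch): a minimal, hypothetical sketch of how the new create_state/update pair in optimizer.py is driven, mirroring what _train_multi_device now does per parameter index. The shapes, values, and hyper-parameters below are made up for illustration only.

    from mxnet.ndarray import zeros
    from mxnet.optimizer import SGD

    # Toy parameter and gradient; real training fills these from the executors.
    weight = zeros((4, 2))
    weight[:] = 1.0
    grad = zeros((4, 2))
    grad[:] = 0.5

    opt = SGD(learning_rate=0.1, momentum=0.9)
    # The per-weight state (SGD's momentum buffer, or None when momentum == 0)
    # is created once and then passed back into every update() call, keyed by
    # the same integer index the training loop uses for this parameter.
    state = opt.create_state(0, weight)
    for _ in range(3):
        opt.update(0, weight, grad, state)

    print(weight.asnumpy())  # weight has been moved against the gradient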