diff --git a/.travis.yml b/.travis.yml index a0f86a2bfdf8..88c837798c01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ sudo: false # Enabling test on Linux and OS X os: - linux + - osx # Use Build Matrix to do lint and build seperately env: diff --git a/README.md b/README.md index b7afa97dd8c4..d19b993792ad 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,18 @@ [![Documentation Status](https://readthedocs.org/projects/mxnet/badge/?version=latest)](http://mxnet.readthedocs.org/en/latest/) [![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)]() -MXNet is a deep learning framework designed for both *efficiency* and *flexibility*. It -aims for people +MXNet is a deep learning framework designed for both *efficiency* and *flexibility*. +It allows you to mix the [flavors](http://mxnet.readthedocs.org/en/latest/program_model.html) of +deep learning programs together to maximize the efficiency and your productivity. -- Who want to apply deep learning for applications. One can use only several lines of codes - to create and train a neural network with high efficiency. Check our - [examples](example) for more details. -- Who want to use it for research on deep learning. MXNet provides flexible - programming interface for rapid prototyping. For example, check our - [tutorials for Python](http://mxnet.readthedocs.org/en/latest/python/tutorial.html) +What's New +---------- +* [Note on Programming Models for Deep Learning](http://mxnet.readthedocs.org/en/latest/program_model.html) Contents -------- * [Documentation](http://mxnet.readthedocs.org/en/latest/) +* [Code Examples](example) * [Build Instruction](doc/build.md) * [Features](#features) * [License](#license) diff --git a/doc/img/comp_grad_graph.png b/doc/img/comp_grad_graph.png new file mode 100644 index 000000000000..7b19bb45f703 Binary files /dev/null and b/doc/img/comp_grad_graph.png differ diff --git a/doc/img/comp_graph.png b/doc/img/comp_graph.png new file mode 100644 index 000000000000..c5dc88b2b882 Binary files /dev/null and b/doc/img/comp_graph.png differ diff --git a/doc/img/comp_graph_folded.png b/doc/img/comp_graph_folded.png new file mode 100644 index 000000000000..723ca7ad6c58 Binary files /dev/null and b/doc/img/comp_graph_folded.png differ diff --git a/doc/index.md b/doc/index.md index 12f1be6b8591..866bab6c3e83 100644 --- a/doc/index.md +++ b/doc/index.md @@ -15,6 +15,7 @@ User Guide Developer Guide --------------- +* [Programming Models for Deep Learning](program_model.md) * [Developer Documents](developer-guide/index.md) * [Environment Variables for MXNet](env_var.md) * [Contributor Guideline](contribute.md) diff --git a/doc/program_model.md b/doc/program_model.md new file mode 100644 index 000000000000..5fa1bc30bf75 --- /dev/null +++ b/doc/program_model.md @@ -0,0 +1,412 @@ +Programming Models for Deep Learning +==================================== +There are a lot of deep learning libraries, each comes with its own flavor. +How can each flavor introduced by each library provide advantage or drawbacks in terms of system optimization and user experience? +This article aims to compare these flavors in terms of programming models, discuss the fundenmental advantage and drawbacks +introduced by these model, and how we can learn from them. + +We will focus on the programming model itself instead of the implementations. So this article is not about benchmarking +deep learning libaries. 
Instead, we will divide the libraries into several categories according to the user interface they offer,
and discuss how each style of interface affects the performance and flexibility of deep learning programs.
The discussion in this article is not specific to deep learning, but we will keep deep learning applications as our use-case and optimization goal.

Symbolic vs Imperative Programs
-------------------------------
Let us get started: the first thing we are going to compare is symbolic style programs versus imperative style programs.
If you are a python or c++ programmer, you are likely already familiar with imperative programs.
Imperative style programs conduct the computation as we run them. Most code you write in python is imperative,
for example, the following numpy snippet.
```python
import numpy as np
a = np.ones(10)
b = np.ones(10) * 2
c = b * a
d = c + 1
```
When the program executes ```c = b * a```, it runs the actual computation. Symbolic programs are a bit different.
The following snippet is an equivalent symbolic style program you can write to achieve the same goal of calculating ```d```.
```python
A = Variable('A')
B = Variable('B')
C = B * A
D = C + Constant(1)
# compiles the function
f = compile(D)
d = f(A=np.ones(10), B=np.ones(10)*2)
```
The difference in symbolic programs is that when ```C = B * A``` is executed, no actual computation happens.
Instead, these operations generate a computation graph (symbolic graph) that represents the computation being described.
The following picture shows the computation graph that computes ```D```.

![Comp Graph](img/comp_graph.png)

Most symbolic style programs contain, either explicitly or implicitly, a ```compile``` step.
This converts the computation graph into a function that can be called.
The real computation then happens in the last step of the code. The major characteristic of symbolic programs
is the clear separation between the computation graph definition step and the compilation and running steps.

Examples of imperative style deep learning libraries include Torch, Chainer, and Minerva,
while examples of symbolic style deep learning libraries include Theano and CGT.
Libraries that use configuration files, such as cxxnet and caffe, can also be viewed as symbolic style libraries,
where the configuration file content defines the computation graph.

Now that you know the two programming models, let us compare them!

### Imperative Programs are More Flexible

This is a general statement that may not apply strictly, but imperative programs are indeed usually more flexible than symbolic programs.
If you are writing an imperative style program in python, you are writing in python. However, if you are writing a symbolic program,
it is different. Consider the following imperative program, and think about how you would translate it into a symbolic program.
```python
import numpy as np
a = 2
b = a + 1
d = np.zeros(10)
for i in range(b):
    d += np.zeros(10)
```
You will find it is actually not easy, because there is a python for-loop that may not be readily supported by the symbolic API.
If you are writing a symbolic program in python, you are NOT writing in python.
Instead, you are writing a domain specific language defined by the symbolic API.
Symbolic APIs are more powerful versions of the DSLs that generate computation graphs or neural net configurations.
In that sense, the config-file input libraries are all symbolic.
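Because imperative code executes eagerly, every intermediate value is an ordinary host-language object. The following small sketch (plain numpy, nothing framework-specific is assumed) shows how naturally we can inspect intermediate values and branch on them in the middle of a computation:

```python
import numpy as np

a = np.ones(10)
b = np.ones(10) * 2
c = b * a
print(c.sum())        # inspect an intermediate result immediately
if c.sum() > 10:      # branch on a computed value with plain Python
    d = c + 1
else:
    d = c - 1
```

In a symbolic program, both the print and the branch would have to be expressed as graph operations, or deferred until after ```compile```, because ```c``` does not hold a value while the graph is being constructed.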

Because imperative programs are more ```native``` than symbolic ones, it is easier to use native language features
and inject them into the computation flow, such as printing values in the middle of a computation, or using the host language's conditionals and loops, exactly as in the snippet above.

### Symbolic Programs are More Efficient

As we saw in the previous section, imperative programs are usually more flexible and more native to the host language.
Why, then, have a larger portion of deep learning libraries chosen to be symbolic instead? The main reason is efficiency, in terms of both memory and runtime.
Let us consider the same toy example used at the beginning of this section.

```python
import numpy as np
a = np.ones(10)
b = np.ones(10) * 2
c = b * a
d = c + 1
...
```

![Comp Graph](img/comp_graph.png)

Assume each cell in the array costs 8 bytes. How much memory do we need if we execute the above program in a python console?
Let us do some math: we need memory for 4 arrays of size 10, which means we will need ```4 * 10 * 8 = 320``` bytes. On the other hand,
to execute the computation graph, we can reuse the memory of ```C``` for ```D``` and do the last computation in place, which gives us ```3 * 10 * 8 = 240```
bytes instead.

Symbolic programs are more ***restricted***. When the user calls ```compile``` on ```D```, the user tells the system that only the value of
```D``` is needed. The intermediate values of the computation, in our case ```C```, are invisible to the user.
This allows symbolic programs to safely reuse memory and compute in place.

Imperative programs, on the other hand, need to ***be prepared for all possible futures***. If the above program is executed in a python console,
any of these variables could be used in the future, which prevents the system from sharing their memory.

Of course this argument is a bit idealized, since garbage collection happens in imperative programs when variables go out of scope, and memory can then be reused.
However, the constraint of being "prepared for all possible futures" is real, and it limits the optimizations we can do. This holds for non-trivial cases such
as gradient calculation, which we will discuss in the next section.

Another optimization symbolic programs can do is operation folding. In the above program, the multiplication and the addition can be folded into one operation,
as represented in the following graph. This means only one GPU kernel will be launched (instead of two) if the computation runs on a GPU.
This is what happens in the hand-crafted operations of optimized libraries such as cxxnet and caffe, and it improves computation efficiency.

![Comp Graph Folded](img/comp_graph_folded.png)

We cannot do this in imperative programs, because the intermediate values may be referenced at some point in the future. The reason such optimization is possible in symbolic programs
is that we have the entire computation graph and a clear boundary on which values are needed and which are not,
while imperative programs only operate locally and have no such boundary.

### Case Study on Backprop and AutoDiff

In this section, we will compare the two programming models on the problem of automatic differentiation, or backpropagation. Gradient calculation is
a problem that every deep learning library needs to solve. It is possible to do gradient calculation in both imperative and symbolic style.

Let us start with imperative programs. The following snippet is minimal python code that performs automatic differentiation on the toy example we discussed.
```python
class array(object):
    """Simple Array object that supports autodiff."""
    def __init__(self, value, name=None):
        self.value = value
        if name:
            self.grad = lambda g: {name: g}

    def __add__(self, other):
        assert isinstance(other, int)
        ret = array(self.value + other)
        ret.grad = lambda g: self.grad(g)
        return ret

    def __mul__(self, other):
        assert isinstance(other, array)
        ret = array(self.value * other.value)
        def grad(g):
            x = self.grad(g * other.value)
            x.update(other.grad(g * self.value))
            return x
        ret.grad = grad
        return ret

# some examples
a = array(1, 'a')
b = array(2, 'b')
c = b * a
d = c + 1
print d.value
print d.grad(1)
# Results
# 3
# {'a': 2, 'b': 1}
```

In the above program, each array object contains a grad function (which is actually a closure).
When we run ```d.grad```, it recursively invokes the grad functions of its inputs, backprops the gradient values, and
returns the gradient value of each input. This may look a bit complicated, so let us consider the gradient calculation in
symbolic programs. The program below performs the symbolic gradient calculation for the same task.

```python
A = Variable('A')
B = Variable('B')
C = B * A
D = C + Constant(1)
# get gradient node.
gA, gB = D.grad(wrt=[A, B])
# compiles the gradient function.
f = compile([gA, gB])
grad_a, grad_b = f(A=np.ones(10), B=np.ones(10)*2)
```

The grad function of ```D``` generates a backward computation graph and returns the gradient nodes ```gA, gB```,
which correspond to the red nodes in the following figure.

![Comp Grad Graph](img/comp_grad_graph.png)

What the imperative program does is actually the same as the symbolic approach: it implicitly saves a backward
computation graph in the grad closure. When we invoke ```d.grad```, we start from ```g[D]```,
backtrack through the graph to compute the gradients, and collect the results.

So we find that gradient calculation in both symbolic and imperative programming follows the same
pattern. What, then, is the difference between the two? Recall again the "be prepared for all possible futures"
requirement of imperative programs. If we are making an array library that supports automatic differentiation,
we have to keep the grad closures along with the computation. This means none of the history variables can be
garbage collected, because they are referenced by variable ```d``` via function closures.
Now, what if we only want to compute the value of ```d```, and do not need the gradient?

In symbolic programming, the user instead declares the need with ```f = compile([D])```. This also declares the boundary
of computation, telling the system to compute only the forward pass. As a result, the system can
free the memory of previous results and share memory between inputs and outputs.

Imagine running not this toy example but a deep neural net with ```n``` layers.
If we only run the forward pass, and not the backward (gradient) pass, we need to allocate only 2 copies of
temporary space to store the values of intermediate layers, instead of ```n``` copies of them.
However, because imperative programs need to be prepared for the possible future of being asked for gradients,
the intermediate values have to be stored, which requires ```n``` copies of temporary space.
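The following sketch makes this concrete in plain numpy (the ```layer``` function and the sizes are made up purely for illustration): a forward-only pass can ping-pong between two preallocated buffers, while a pass that must support backprop has to keep every intermediate activation alive.

```python
import numpy as np

n_layers = 8
dim = 10
weights = [np.random.randn(dim, dim) * 0.1 for _ in range(n_layers)]

def layer(x, w):
    return np.maximum(x.dot(w), 0)   # a toy relu layer

x = np.ones(dim)

# Forward-only: only two long-lived activation buffers are kept, regardless of depth.
buf = [x.copy(), np.empty(dim)]
for i, w in enumerate(weights):
    src, dst = buf[i % 2], buf[(i + 1) % 2]
    dst[:] = layer(src, w)
out_inference = buf[n_layers % 2]

# Training-style: keep all n intermediates because backward will need them.
activations = [x]
for w in weights:
    activations.append(layer(activations[-1], w))
out_training = activations[-1]
```

The two-buffer trick is exactly the kind of optimization a symbolic system can apply automatically once the user has declared, via ```compile```, that only the final output is needed.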

As we can see, the level of optimization possible depends on the restrictions placed on what the user can do. The idea of symbolic
programs is to ask the user to clearly specify the boundary of computation via ```compile``` or its equivalent,
while imperative programs must prepare for all possible futures. Symbolic programs get a natural advantage
by seeing more of what the user wants and does not want.

Of course, we can also enhance imperative programs to impose restrictions. For example, one solution to the above
problem is to introduce a context variable: a no-gradient context that switches gradient calculation off.
This brings a bit more restriction into the imperative program, in exchange for efficiency.

```python
with context.NoGradient():
    a = array(1, 'a')
    b = array(2, 'b')
    c = b * a
    d = c + 1
```

However, the above example still has many possible futures, which means we cannot do the in-place calculation
to reuse memory in the forward pass (a trick commonly used to reduce GPU memory usage).
The techniques introduced in this section generate an explicit backward pass.
In some toolkits, such as caffe and cxxnet, backprop is done implicitly on the same graph;
the discussion in this section applies to those cases as well.

Most configuration-file based libraries, such as cxxnet and caffe, are designed for one or two generic requirements:
get the activation of each layer, or get the gradients of all the weights. The same problem stays with these libraries:
the more generic operations the library has to support, the less optimization (memory sharing) it can do on the same data structure.

So, as you can see, the trade-off between restriction and flexibility persists in most cases.

### Model Checkpoint

Being able to save a model and load it back later is important for most users. There are different ways to ```save``` your work.
Normally, saving a neural net requires saving two things: a configuration describing the structure of the neural net, and the weights of the neural net.

Being able to checkpoint the configuration is a plus for symbolic programs. Because the symbolic construction phase does not contain any computation,
we can directly serialize the computation graph and load it back later; this solves the configuration saving problem without introducing an additional layer.

```python
A = Variable('A')
B = Variable('B')
C = B * A
D = C + Constant(1)
D.save('mygraph')
...
D2 = load('mygraph')
f = compile([D2])
# more operations
...
```

Because an imperative program executes as it describes the computation, we have to save the code itself as the ```configuration```, or build another
configuration layer on top of the imperative language.

### Parameter Update

Most symbolic programs are data flow (computation) graphs. Data flow graphs describe computation conveniently.
However, it is not obvious how to use a data flow graph to describe parameter updates, because parameter updates introduce mutation,
which is not a data flow concept. What most symbolic programs do is introduce a special update statement that updates some persistent
state of the program.

It is usually easier to write parameter updates in an imperative style, especially when we need multiple updates that relate to each other.
For symbolic programs, the update statement is also executed as we call it. So in that sense, most existing symbolic deep learning libraries
fall back to the imperative way to perform updates, while using the symbolic way to do the gradient calculation.
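To illustrate why updates fit the imperative style so well, here is a minimal SGD-with-momentum sketch in plain numpy (the gradient is a stand-in value; in a real library it would come from the gradient calculation discussed above): the update is an in-place mutation of persistent state, which is exactly the part that does not map cleanly onto a pure data flow graph.

```python
import numpy as np

weight = np.random.randn(10)      # persistent state that lives across iterations
momentum = np.zeros(10)
lr, mom = 0.1, 0.9

for step in range(100):
    grad = 2 * weight             # stand-in gradient; in practice this comes from backprop
    momentum[:] = mom * momentum - lr * grad
    weight += momentum            # in-place mutation of persistent state
```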

### There is no Strict Boundary

We have compared the two programming styles. Some of the arguments made may not be strictly true, and there is no clear boundary between
the programming styles. For example, we can make a (JIT) compiler for python that compiles imperative python programs, which gives us some of the advantages of the global
information held by symbolic programs. However, most of the principles hold true in general, and these constraints apply when we are building deep learning
libraries.


Big vs Small Operations
-----------------------
Now that we have passed through the battlefield between symbolic and imperative programs, let us talk about the operations supported by deep learning libraries.
There are usually two types of operations supported by different deep learning libraries:
- Big layer operations such as FullyConnected and BatchNormalize.
- Small operations such as element-wise addition and multiplication.
Libraries like cxxnet and caffe support layer-level operations, while libraries like Theano and Minerva support fine-grained operations.

### Smaller Operations can be More Flexible
This is quite natural, in the sense that we can always use smaller operations to compose bigger ones.
For example, the sigmoid unit can simply be composed from division and exponentiation.
```python
sigmoid(x) = 1.0 / (1.0 + exp(-x))
```
If we have the smaller operations as building blocks, we can express most of the problems we want.
For readers more familiar with cxxnet/caffe style layers: these operations are no different from a layer, except that they are smaller.
```python
SigmoidLayer(x) = EWiseDivisionLayer(1.0, AddScalarLayer(ExpLayer(-x), 1.0))
```
So the above expression becomes a composition of three layers, each defining its forward and backward (gradient) function.
This lets us build new layers quickly, because we only need to compose existing pieces together.

### Big Operations are More Efficient
As you can see, directly composing the sigmoid out of layers means we need three layers of operations instead of one.
```python
SigmoidLayer(x) = EWiseDivisionLayer(1.0, AddScalarLayer(ExpLayer(-x), 1.0))
```
This creates overhead in terms of computation and memory (which could be optimized, at a cost).

So libraries like cxxnet and caffe take a different approach: they support coarse-grained operations,
such as BatchNormalization and SigmoidLayer, directly. In each of these layers, the calculation kernel is hand-crafted
with one or only a few CUDA kernel launches, which makes these implementations more efficient.

### Compilation and Optimization

Can the small operations be optimized? Of course they can. This is where the system optimization part of the compilation engine comes in.
There are two types of optimization that can be done on the computation graph:
- Memory allocation optimization, to reuse the memory of intermediate computations.
- Operator fusion, to detect subgraph patterns such as the sigmoid and fuse them into a bigger operation kernel.
Memory allocation optimization is not restricted to small-operation graphs; it can be applied to big-operation graphs as well. The toy sketch below illustrates what fusion and memory reuse buy.
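As a toy illustration in plain numpy (for exposition only; a real graph optimizer would emit a single GPU kernel rather than reuse a numpy buffer), compare a sigmoid built from separate element-wise steps, each materializing a temporary array, with a "fused" version that performs all steps in one preallocated buffer:

```python
import numpy as np

x = np.random.randn(1000000)

def sigmoid_composed(x):
    t1 = -x                    # temporary 1
    t2 = np.exp(t1)            # temporary 2
    t3 = 1.0 + t2              # temporary 3
    return 1.0 / t3            # temporary 4

def sigmoid_fused(x):
    out = np.empty_like(x)
    np.negative(x, out=out)    # all steps write into one preallocated buffer
    np.exp(out, out=out)
    out += 1.0
    np.reciprocal(out, out=out)
    return out

assert np.allclose(sigmoid_composed(x), sigmoid_fused(x))
```

The composed version allocates four temporaries per call; the fused version allocates one output buffer. This is the kind of saving that operator fusion and in-place memory planning provide automatically.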

However, these optimizations may not be essential for big-operation libraries like cxxnet and caffe; you will never find an explicit compilation step in them.
Actually there still is a ```compilation step``` that translates the layers into a fixed forward and backprop execution plan, which runs each operation one by one.

For computation graphs with smaller operations, these optimizations are crucial to performance. Because the operations are small, there are many subgraph patterns
that can be matched. Also, because the finally generated operations may not be enumerable in advance, an explicit recompilation of the kernels is required, as opposed to
the fixed set of pre-compiled kernels in the big-operation libraries. This is the cause of the compilation overhead in symbolic libraries that support small operations.
The requirement of compilation optimization also creates engineering overhead for libraries that solely support smaller operations.

As in the symbolic vs imperative case, the big-operation libraries "cheat" by asking the user to accept restrictions (to the common layers provided),
so the user is actually the one doing the subgraph matching. This moves the compilation overhead to the human brain, which is usually not too bad.

### Expression Template and Statically Typed Language

As we can see, we always have a need to write small operations and compose them together.
Libraries like caffe use hand-crafted kernels to build these bigger blocks; otherwise, users have to compose smaller operations from the python side.

Actually, there is a third choice that works pretty well, called expression templates. The idea is to use template programming to
generate generic kernels from an expression tree at compile time. You can refer to the [Expression Template Tutorial](https://github.com/dmlc/mshadow/blob/master/guide/exp-template/README.md)
for more details. CXXNet is a library that makes extensive use of expression templates; this enables much shorter and more readable code, with
performance matching hand-crafted kernels.

The difference between expression templates and python kernel generation is that the expression evaluation is done at c++ compile time with an existing type,
so there is no additional runtime overhead. This is in principle also possible with other statically typed languages that support templates,
but we have only seen this trick in C++ so far.

Expression template libraries create a middle ground between python operations and hand-crafted big kernels, allowing C++ users to craft efficient big
operations by composing smaller ones together. It is a choice worth considering.

Mix The Flavors Together
------------------------
Now that we have compared the programming models, the question is which one to choose.
Before answering, we should emphasize that the comparison made in this article may not necessarily have a big impact,
depending on where your problem's bottlenecks are.

Remember [Amdahl's law](https://en.wikipedia.org/wiki/Amdahl%27s_law): if you are optimizing a non-performance-critical
part of your problem, you won't get much performance gain.

As we have seen, there is usually a trade-off between efficiency, flexibility, and engineering complexity,
and different programming styles fit different parts of the problem.
For example, imperative programs are more natural for parameter updates, and symbolic programs for gradient calculation.

What this article advocates is to ***mix*** the flavors together. Recall Amdahl's law: sometimes the parts we want to be flexible
are not performance-critical, and it is OK to be a bit sloppy there in order to support more flexible interfaces.
In machine learning, an ensemble of different methods usually works better than a single one.

If the programming models can be mixed in the right way, we can get more benefit than from any single programming model.
We list some of the possible combinations here.

### Symbolic and Imperative Programs
There are two ways to mix symbolic and imperative programs:
- Put imperative programs inside symbolic programs as callbacks.
- Put symbolic programs inside imperative programs.

What we usually observe is that it is helpful to write parameter updates in an imperative way,
while gradient calculations can be done more effectively in symbolic programs.

This mixing already happens in existing symbolic libraries, because python itself is imperative.
For example, the following program mixes the symbolic part with numpy (which is imperative).
```python
A = Variable('A')
B = Variable('B')
C = B * A
D = C + Constant(1)
# compiles the function
f = compile(D)
d = f(A=np.ones(10), B=np.ones(10)*2)
d = d + 1.0
```
The idea is that the symbolic graph is compiled into a function that can be executed imperatively, whose internals are a black box to the user.
This is exactly like writing c++ programs and exposing them to python, which we commonly do.

However, using numpy as the imperative component might be undesirable, because the parameter memory resides on the GPU. A better way might be to support
a GPU-compatible imperative library that interacts with the compiled symbolic functions, or to provide a limited amount of update syntax via
update statements in symbolic program execution.

### Small and Big Operations

Combining small and big operations is also possible, and we actually have good reasons to do it. Consider applications such as changing
a loss function or adding a few customized layers to an existing structure. What we usually do is use big operations to compose the existing
components and smaller operations to build the new parts.

Recall Amdahl's law: these new components are usually not the computational bottleneck. Since the performance-critical part is already optimized by
the bigger operations, it is even OK not to optimize the additional small operations at all, or to do only a bit of memory optimization instead
of operation fusion and run them directly.

### Choose your Own Flavors

We have compared the flavors of deep learning programs. The goal of this article is to list these choices and compare their trade-offs.
There may not be a universal solution for everyone, but you can always choose your flavor, or combine the flavors you like, to create
more interesting and intelligent deep learning libraries.

Contribution to this Note
-------------------------
This note is part of our effort to open-source not only code but also the system design notes behind deep learning libraries.
You are very welcome to contribute to this note by submitting a pull request.
diff --git a/doc/python/ndarray.md b/doc/python/ndarray.md index d5cc48ee64db..6c70a938144d 100644 --- a/doc/python/ndarray.md +++ b/doc/python/ndarray.md @@ -8,9 +8,12 @@ Create NDArray Like `numpy`, you could create `mxnet.ndarray` like followings: ```python >>> import mxnet as mx ->>> a = mx.nd.zeros((100, 50)) # all-zero array of dimension 100x50 ->>> b = mx.nd.ones((256, 32, 128, 1)) # all-one array of dimension 256x32x128x1 ->>> c = mx.nd.array([[1, 2, 3], [4, 5, 6]]) # initialize array with contents +>>> # all-zero array of dimension 100x50 +>>> a = mx.nd.zeros((100, 50)) +>>> # all-one array of dimension 256x32x128x1 +>>> b = mx.nd.ones((256, 32, 128, 1)) +>>> # initialize array with contents +>>> c = mx.nd.array([[1, 2, 3], [4, 5, 6]]) ``` NDArray operations @@ -24,9 +27,11 @@ We provide some basic ndarray operations like arithmetic and slice operations. M >>> a.shape (100L, 50L) >>> b = mx.nd.ones((100, 50)) +>>> # c and d will be calculated in parallel here! >>> c = a + b ->>> d = a - b # c and d will be calculated in parallel here! ->>> b += d # inplace operation, b's contents will be modified, but c and d won't be affected. +>>> d = a - b +>>> # inplace operation, b's contents will be modified, but c and d won't be affected. +>>> b += d ``` ### Slice operations @@ -36,8 +41,8 @@ We provide some basic ndarray operations like arithmetic and slice operations. M >>> a[0:10] = 1 # first 10 rows will become 1 ``` -Conversion from/to `numpy.ndarray` and I/O --------------------------------- +Conversion from/to `numpy.ndarray` +---------------------------------- MXNet NDArray supports pretty nature way to convert from/to `mxnet.ndarray` to/from `numpy.ndarray`: ```python >>> import mxnet as mx @@ -50,13 +55,20 @@ MXNet NDArray supports pretty nature way to convert from/to `mxnet.ndarray` to/f array([ 1., 2., 3.], dtype=float32) ``` -We also provide two convenient functions to help save and load file from I/O: +Save Load NDArray +----------------- +You can always use pickle to save and load NDArrays. +We also provide functions to help save and load list or dictionary of NDArrays from file systems. ```python >>> import mxnet as mx >>> a = mx.nd.zeros((100, 200)) ->>> mx.nd.save("/path/to/array/file", a) ->>> mx.nd.save("s3://path/to/s3/array", a) ->>> mx.nd.save("hdfs://path/to/hdfs/array", a) +>>> b = mx.nd.zeros((100, 200)) +>>> # save list of NDArrays +>>> mx.nd.save("/path/to/array/file", [a, b]) +>>> # save dictionary of NDArrays to AWS S3 +>>> mx.nd.save("s3://path/to/s3/array", {'A' : a, 'B' : b}) +>>> # save list of NDArrays to hdfs. +>>> mx.nd.save("hdfs://path/to/hdfs/array", [a, b]) >>> from_file = mx.nd.load("/path/to/array/file") >>> from_s3 = mx.nd.load("s3://path/to/s3/array") >>> from_hdfs = mx.nd.load("hdfs://path/to/hdfs/array") @@ -65,8 +77,8 @@ The good thing about using the above `save` and `load` interface is that: - You could use the format across all `mxnet` language bindings. - Already support S3 and HDFS. -Multi-device support -------------------- +Multi-device Support +-------------------- The device information is stored in `mxnet.Context` structure. When creating ndarray in mxnet, user could either use the context argument (default is CPU context) to create arrays on specific device or use the `with` statement as follows: ```python >>> import mxnet as mx diff --git a/doc/python/tutorial.md b/doc/python/tutorial.md index fb818f7c5071..14d92c2bb26a 100644 --- a/doc/python/tutorial.md +++ b/doc/python/tutorial.md @@ -315,17 +315,34 @@ shape inconsistency. 
### Bind the Symbols and Run -Now we can bind the free variables of the symbol and perform forward and -backward. +Now we can bind the free variables of the symbol and perform forward and backward. +The bind function will create a ```Executor``` that can be used to carry out the real computations. ```python ->>> in_shape = (128, 3, 100, 100) # minibatch_size, #channel, image_width, image_height ->>> executor = net.simple_bind(mx.gpu(), data = mx.nd.empty(in_shape, mx.gpu()) ->>> # feed data and label.. ->>> executor.forward() ->>> executor.backward() ->>> print executor.outputs[0].asnumpy() +>>> # define computation graphs +>>> A = mx.symbol.Variable('A') +>>> B = mx.symbol.Variable('B') +>>> C = A * B +>>> a = mx.nd.ones(3) * 4 +>>> b = mx.nd.ones(3) * 2 +>>> # bind the symbol with real arguments +>>> c_exec = C.bind(ctx=mx.cpu(), args={'A' : a, 'B': b}) +>>> # do forward pass calclation. +>>> c_exec.forward() +>>> c_exec.outputs[0].asnumpy() +[ 8. 8. 8.] ``` +For neural nets, a more commonly used pattern is ```simple_bind```, which will create +all the arguments arrays for you. Then you can call forward, and backward(if gradient is needed) +to get the gradient. +```python +>>> # define computation graphs +>>> net = some symbol +>>> texec = net.simple_bind(data=input_shape) +>>> texec.forward() +>>> texec.backward() +``` +The [model API](../../python/mxnet/model.py) is a thin wrapper around the symbolic executors to support neural net training. ### How Efficient is Symbolic API diff --git a/example/README.md b/example/README.md index 0765885349f8..8564e1e82d75 100644 --- a/example/README.md +++ b/example/README.md @@ -4,10 +4,10 @@ This folder contains examples of MXNet. Notebooks -------- -* [composite symbol](composite_symbol.ipynb) gives you a demo of how to composite a symbolic Inception-BatchNorm Network -* [cifar-10 recipe](cifar-recipe.ipynb) gives you a step by step demo of how to use MXNet -* [cifar-100](cifar-100.ipynb) gives you a demo of how to train a 75.68% accuracy CIFAR-100 model -* [predict with pretained model](predict-with-pretrained-model.ipynb) gives you a demo of use a pretrained Inception-BN Network +* [composite symbol](notebooks/composite_symbol.ipynb) gives you a demo of how to composite a symbolic Inception-BatchNorm Network +* [cifar-10 recipe](notebooks/cifar-recipe.ipynb) gives you a step by step demo of how to use MXNet +* [cifar-100](notebooks/cifar-100.ipynb) gives you a demo of how to train a 75.68% accuracy CIFAR-100 model +* [predict with pretained model](notebooks/predict-with-pretrained-model.ipynb) gives you a demo of use a pretrained Inception-BN Network Contents diff --git a/example/imagenet/alexnet.py b/example/imagenet/alexnet.py index e4e1663406c4..9a74631a2174 100644 --- a/example/imagenet/alexnet.py +++ b/example/imagenet/alexnet.py @@ -16,7 +16,7 @@ conv2 = mx.symbol.Convolution( data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256) relu2 = mx.symbol.Activation(data=conv2, act_type="relu") -pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2)) +pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max") lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5) # stage 3 conv3 = mx.symbol.Convolution( @@ -28,7 +28,7 @@ conv5 = mx.symbol.Convolution( data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256) relu5 = mx.symbol.Activation(data=conv5, act_type="relu") -pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2)) +pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), 
stride=(2, 2), pool_type="max") # stage 4 flatten = mx.symbol.Flatten(data=pool3) fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096) @@ -48,7 +48,7 @@ train, val = ilsvrc12_iterator(batch_size=batch_size, input_shape=(3,224,224)) ## train -num_gpus = 2 +num_gpus = 4 gpus = [mx.gpu(i) for i in range(num_gpus)] model = mx.model.FeedForward( ctx = gpus, diff --git a/example/imagenet/data.py b/example/imagenet/data.py index cfca1db5e084..2f53902b3c96 100644 --- a/example/imagenet/data.py +++ b/example/imagenet/data.py @@ -7,8 +7,8 @@ def ilsvrc12_iterator(batch_size, input_shape): """return train and val iterators for imagenet""" train_dataiter = mx.io.ImageRecordIter( - path_imgrec = "data/ilsvrc12/train.rec", - mean_img = "data/ilsvrc12/mean.bin", + path_imgrec = "data/train.rec", + mean_img = "data/mean.bin", rand_crop = True, rand_mirror = True, prefetch_buffer = 4, @@ -16,8 +16,8 @@ def ilsvrc12_iterator(batch_size, input_shape): data_shape = input_shape, batch_size = batch_size) val_dataiter = mx.io.ImageRecordIter( - path_imgrec = "data/ilsvrc12/val.rec", - mean_img = "data/ilsvrc12/mean.bin", + path_imgrec = "data/val.rec", + mean_img = "data/mean.bin", rand_crop = False, rand_mirror = False, prefetch_buffer = 4, diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 71d303ff01f3..da7a8aaa5388 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -39,16 +39,16 @@ class Storage { * \param ctx Context information about the device and ID. * \return Handle struct. */ - Handle Alloc(size_t size, Context ctx); + virtual Handle Alloc(size_t size, Context ctx) = 0; /*! * \brief Free storage. * \param handle Handle struect. */ - void Free(Handle handle); + virtual void Free(Handle handle) = 0; /*! * \brief Destructor. */ - ~Storage(); + virtual ~Storage() {} /*! * \return Storage singleton. */ @@ -62,15 +62,6 @@ class Storage { * \return A shared pointer to Storage singleton. */ static std::shared_ptr _GetSharedRef(); - - private: - /*! - * \brief Hidden constructors. - */ - Storage(); - struct Impl; - std::unique_ptr impl_; - DISALLOW_COPY_AND_ASSIGN(Storage); }; // class Storage } // namespace mxnet #endif // MXNET_STORAGE_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index b87b9dad924c..5792df7c4039 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -1,11 +1,6 @@ #!/usr/bin/env python # coding: utf-8 -"""MXNet: a concise, fast and flexible framework for deep learning - -MXNet is a project that evolves from cxxnet, minerva and purine2. -The interface is designed in collaboration by authors of three projects. - -""" +"""MXNet: a concise, fast and flexible framework for deep learning. """ from __future__ import absolute_import from .context import Context, current_context, cpu, gpu @@ -14,6 +9,7 @@ from . import ndarray from . import name from . import symbol +# use mx.kv as short for kvstore from . import kvstore as kv from . 
import io # use mx.nd as short for mx.ndarray diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index 4d121adf7670..03fcb5a85071 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -1,5 +1,7 @@ -# pylint: disable=logging-not-lazy, blacklisted-name, invalid-name -"""model helper for knowing training status""" +# coding: utf-8 +"""Callback functions that can be used to track various status during iteration.""" +from __future__ import absolute_import + import sys import math import logging @@ -19,11 +21,12 @@ def do_checkpoint(prefix): callback : function The callback function that can be passed as iter_end_callback to fit. """ - def _callback(iter_no, s, arg, aux): + def _callback(iter_no, sym, arg, aux): """The checkpoint function.""" - save_checkpoint(prefix, iter_no + 1, s, arg, aux) + save_checkpoint(prefix, iter_no + 1, sym, arg, aux) return _callback + class Speedometer(object): """Calculate training speed in frequent @@ -57,12 +60,13 @@ def __call__(self, count): if self.init: if count % self.frequent == 0: speed = self.frequent * self.batch_size / (time.time() - self.tic) - logging.info("Batch [%d]\tSpeed: %.2f samples/sec" % (count, speed)) + logging.info("Batch [%d]\tSpeed: %.2f samples/sec", count, speed) self.tic = time.time() else: self.init = True self.tic = time.time() + class ProgressBar(object): """Show a progress bar @@ -89,7 +93,7 @@ def __call__(self, count): filled_len = int(round(self.bar_len * count / float(self.total))) percents = math.ceil(100.0 * count / float(self.total)) - bar = '=' * filled_len + '-' * (self.bar_len - filled_len) - sys.stdout.write('[%s] %s%s\r' % (bar, percents, '%')) + prog_bar = '=' * filled_len + '-' * (self.bar_len - filled_len) + sys.stdout.write('[%s] %s%s\r' % (prog_bar, percents, '%')) diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 6ed801eaa7f9..1ed3dae5fb23 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -1,5 +1,5 @@ # coding: utf-8 -""" code for context management """ +"""Context management API of mxnet.""" from __future__ import absolute_import class Context(object): @@ -19,7 +19,6 @@ class Context(object): Examples -------- - Switch default context example: >>> # array on cpu >>> cpu_array = mx.md.ones((2, 3)) >>> # switch default context to GPU(2) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 57a1ad1d238c..631204962152 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -1,6 +1,6 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-locals, fixme -""" code for executor. """ +# pylint: disable=invalid-name, protected-access, too-many-locals +"""Symbolic Executor component of MXNet.""" from __future__ import absolute_import import ctypes @@ -11,13 +11,17 @@ class Executor(object): """ Executor is the actual executing object of MXNet.""" - def __init__(self, handle): - """Init an executor from handle + def __init__(self, handle, symbol): + """Constructor, used Symbol.bind and Symbol.simple_bind instead. 
Parameters ---------- handle: ExecutorHandle ExecutorHandle generated by calling Bind + + See Also + -------- + Symbol.bind : to create executor """ if not isinstance(handle, ExecutorHandle): raise TypeError("Handle type error") @@ -26,41 +30,162 @@ def __init__(self, handle): self.grad_arrays = [] self.aux_arrays = [] self.outputs = self._get_outputs() + self._symbol = symbol + self._arg_dict = None + self._grad_dict = None + self._aux_dict = None + + @staticmethod + def _get_dict(names, ndarrays): + """Get the dictionary given name and ndarray pairs.""" + nset = set() + for nm in names: + if nm in nset: + raise ValueError('Duplicate names detected, %s' % str(names)) + nset.add(nm) + return dict(zip(names, ndarrays)) + + def _get_outputs(self): + """list all the output ndarray - def forward(self, is_train=True): - """Do forward. + Returns + ------- + A list of ndarray binded to the heads of executor. + """ + out_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + check_call(_LIB.MXExecutorOutputs(self.handle, + ctypes.byref(out_size), ctypes.byref(handles))) + return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + + def forward(self, is_train=False, **kwargs): + """Calculate the outputs specified by the binded symbol. Parameters ---------- - is_train: bool - whether this forward is for evaluation purpose + is_train: bool, optional + whether this forward is for evaluation purpose. + + **kwargs + Additional specification of input arguments. + + Examples + -------- + >>> # doing forward by specifying data + >>> texec.forward(is_train=True, data=mydata) + >>> # doing forward by not specifying things, but copy to the executor before hand + >>> mydata.copyto(texec.arg_dict['data']) + >>> texec.forward(is_train=True) """ + if len(kwargs) != 0: + arg_dict = self.arg_dict + for name, array in kwargs.items(): + if not isinstance(array, NDArray): + raise ValueError('only accept keyword argument of NDArrays') + if name not in arg_dict: + raise TypeError('Unknown argument %s' % name) + array.copyto(arg_dict[name]) + check_call(_LIB.MXExecutorForward( self.handle, ctypes.c_int(int(is_train)))) - def backward(self, head_grads=None): - """Do backward on heads' gradient. + def backward(self, out_grads=None): + """Do backward pass to get the gradient of arguments. Parameters ---------- - head_grads : NDArray or list of NDArray, optional - Gradient on the heads + out_grads : NDArray or list of NDArray, optional + Gradient on the outputs to be propagated back. + This parameter is only needed when bind is called + on outputs that are not a loss function. """ - if head_grads is None: - head_grads = [] - elif isinstance(head_grads, NDArray): - head_grads = [head_grads] + if out_grads is None: + out_grads = [] + elif isinstance(out_grads, NDArray): + out_grads = [out_grads] - for obj in head_grads: + for obj in out_grads: if not isinstance(obj, NDArray): raise TypeError("inputs must be NDArray") - ndarray = c_array(NDArrayHandle, [item.handle for item in head_grads]) + ndarray = c_array(NDArrayHandle, [item.handle for item in out_grads]) check_call(_LIB.MXExecutorBackward( self.handle, - mx_uint(len(head_grads)), + mx_uint(len(out_grads)), ndarray)) + @property + def arg_dict(self): + """Get dictionary representation of argument arrrays. + + Returns + ------- + arg_dict : dict of str to NDArray + The dictionary that maps name of arguments to NDArrays. + + Raises + ------ + ValueError : if there are duplicated names in the arguments. 
+ """ + if self._arg_dict is None: + self._arg_dict = Executor._get_dict( + self._symbol.list_arguments(), self.arg_arrays) + return self._arg_dict + + @property + def aux_dict(self): + """Get dictionary representation of auxiliary states arrays. + + Returns + ------- + aux_dict : dict of str to NDArray + The dictionary that maps name of auxiliary states to NDArrays. + + Raises + ------ + ValueError : if there are duplicated names in the auxiliary states. + """ + if self._aux_dict is None: + self._aux_dict = Executor._get_dict( + self._symbol.list_auxiliary_states(), self.aux_arrays) + return self._aux_dict + + def copy_params_from(self, arg_params, aux_params=None, allow_extra_params=False): + """Copy parameters from arg_params, aux_params into executor's internal array. + + Parameters + ---------- + arg_params : dict of str to NDArray + Parameters, dict of name to NDArray of arguments + + aux_params : dict of str to NDArray, optional + Parameters, dict of name to NDArray of auxiliary states. + + allow_extra_params : boolean, optional + Whether allow extra parameters that are not needed by symbol + If this is True, no error will be thrown when arg_params or aux_params + contain extra parameters that is not needed by the executor. + + Raises + ------ + ValueError + If there is additional parameters in the dict but allow_extra_params=False + """ + for name, array in arg_params.items(): + if name in self.arg_dict: + array.copyto(self.arg_dict[name]) + else: + if not allow_extra_params: + raise ValueError('Find name \"%s\" that is not in the arguments' % name) + if aux_params is None: + aux_params = {} + for name, array in aux_params.items(): + if name in self.aux_dict: + array.copyto(self.aux_dict[name]) + else: + if not allow_extra_params: + raise ValueError('Find name %s that is not in the auxiliary states' % name) + def debug_str(self): """Get a debug string about internal execution plan. @@ -73,16 +198,3 @@ def debug_str(self): check_call(_LIB.MXExecutorPrint( self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) - - def _get_outputs(self): - """list all heads' output ndarray - - Returns - ------- - A list of ndarray binded to the heads of executor. - """ - out_size = mx_uint() - handles = ctypes.POINTER(NDArrayHandle)() - check_call(_LIB.MXExecutorOutputs(self.handle, - ctypes.byref(out_size), ctypes.byref(handles))) - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 3d26c3cf4e50..44be43892824 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -1,6 +1,5 @@ # coding: utf-8 -# pylint: disable=invalid-name, global-statement -""" KVStore in mxnet """ +""" Key value store interface of MXNet for parameter synchronization.""" from __future__ import absolute_import import ctypes @@ -19,10 +18,10 @@ def _ctype_key_value(keys, vals): return (c_array(ctypes.c_int, [keys]), c_array(NDArrayHandle, [vals.handle])) else: - for v in vals: - assert(isinstance(v, NDArray)) + for value in vals: + assert(isinstance(value, NDArray)) return (c_array(ctypes.c_int, [keys] * len(vals)), - c_array(NDArrayHandle, [v.handle for v in vals])) + c_array(NDArrayHandle, [value.handle for value in vals])) else: assert(len(keys) == len(vals)) for k in keys: @@ -66,7 +65,7 @@ def __del__(self): def init(self, key, value): """ Initialize a single or a sequence of key-value pairs into the store. 
- For each key, one must init it before push and pull + For each key, one must init it before push and pull. Parameters ---------- @@ -102,8 +101,10 @@ def push(self, key, value, priority=0): ---------- key : int or list of int Keys + value : NDArray or list of NDArray or list of list of NDArray According values + priority : int, optional The priority of the push operation. The higher the priority, the faster this action is likely @@ -150,14 +151,16 @@ def push(self, key, value, priority=0): ctypes.c_int(priority))) def pull(self, key, out=None, priority=0): - """ Pull a single value or a sequence of values from the store + """ Pull a single value or a sequence of values from the store. Parameters ---------- key : int or list of int Keys + out: NDArray or list of NDArray or list of list of NDArray According values + priority : int, optional The priority of the push operation. The higher the priority, the faster this action is likely diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index e3aa483b3cd6..af9224a11292 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1,5 +1,7 @@ -# pylint: disable=invalid-name, pointless-string-statement +# coding: utf-8 """Online evaluation metric module.""" +from __future__ import absolute_import + from .base import string_types import numpy @@ -48,10 +50,11 @@ def __init__(self): def update(self, label, pred): pred = pred.asnumpy() label = label.asnumpy().astype('int32') - py = numpy.argmax(pred, axis=1) - self.sum_metric += numpy.sum(py == label) + pred_label = numpy.argmax(pred, axis=1) + self.sum_metric += numpy.sum(pred_label == label) self.num_inst += label.size +# pylint: disable=pointless-string-statement """ class LogLoss(EvalMetric): # remove because it because it is too slow @@ -70,6 +73,7 @@ def update(self, label, pred): self.sum_metric += -numpy.log(p) self.num_inst += label.size """ +# pylint: enable=pointless-string-statement class CustomMetric(EvalMetric): """Custom evaluation metric that takes a NDArray function. @@ -94,7 +98,7 @@ def update(self, label, pred): self.sum_metric += self._feval(label, pred) self.num_inst += 1 - +# pylint: disable=invalid-name def np(numpy_feval, name=None): """Create a customized metric from numpy function. @@ -111,7 +115,7 @@ def feval(label, pred): return numpy_feval(label.asnumpy(), pred.asnumpy()) feval.__name__ = numpy_feval.__name__ return CustomMetric(feval, name) - +# pylint: enable=invalid-name def create(metric): """Create an evaluation metric. diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 8f67bb14e50d..b0b4f46ccb65 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -1,5 +1,5 @@ # pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals -# pylint: disable=too-many-branches, too-many-statements, unused-argument +# pylint: disable=too-many-branches, too-many-statements """MXNet model module""" from __future__ import absolute_import @@ -122,6 +122,7 @@ def _train_multi_device(symbol, ctx, input_shape, begin_round, end_round, optimizer, train_data, eval_data=None, eval_metric=None, iter_end_callback=None, epoch_end_callback=None, + update_on_kvstore=False, logger=None): """Internal training function on multiple devices. @@ -172,12 +173,18 @@ def _train_multi_device(symbol, ctx, input_shape, epoch_end_callback: callable(iteration) A callback that is invoked at end of each batch + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. 
+ logger : logging logger When not specified, default logger will be used. Notes ----- - This function will inplace update the NDArrays in arg_parans and aux_states. + - This function will inplace update the NDArrays in arg_parans and aux_states. + - Turning update_on_kvstore on and off can affect speed of multi-gpu training. + - update_on_kvstore=True works well for inception type nets that contains many small weights. + - update_on_kvstore=False works better for Alexnet style net with bulk weights. """ if logger is None: logger = logging @@ -200,16 +207,14 @@ def _train_multi_device(symbol, ctx, input_shape, aux_blocks = [ [x.aux_arrays[index] for x in train_execs] for index in range(len(train_execs[0].aux_arrays))] - for name, block in zip(arg_names, arg_blocks): - if name in arg_params: - for w in block: - arg_params[name].copyto(w) - for name, block in zip(aux_names, aux_blocks): - if name in aux_params: - for w in block: - aux_params[name].copyto(w) + + for texec in train_execs: + texec.copy_params_from(arg_params, aux_params) # ky value store kv = kvstore.create() if num_device != 1 else None + if kv is None: + update_on_kvstore = False + opt_state_blocks = [] # If there are multiple devices, initialize the weights. for index, pair in enumerate(zip(arg_blocks, grad_blocks)): @@ -218,11 +223,20 @@ def _train_multi_device(symbol, ctx, input_shape, if kv: kv.init(index, arg_list[0]) # attach state direct to weight - opt_list = [optimizer.create_state(index, w) for w in arg_list] - opt_state_blocks.append(opt_list) + if update_on_kvstore: + opt_state_blocks.append(nd.zeros(arg_list[0].shape, cpu())) + else: + opt_list = [optimizer.create_state(index, w) for w in arg_list] + opt_state_blocks.append(opt_list) else: opt_state_blocks.append(None) + def kv_updater(index, grad, weight): + """Internal updater on KVstore, used when update_on_kvstore=True.""" + optimizer.update(index, weight, grad, opt_state_blocks[index]) + if update_on_kvstore: + kv.set_updater(kv_updater) + # Input and output data structure data_index, label_index = _check_arguments(symbol) merged_shape = list(train_execs[0].outputs[0].shape) @@ -246,7 +260,7 @@ def _train_multi_device(symbol, ctx, input_shape, data[islice].copyto(target) # forward backward pass for texec, islice in zip(train_execs, slices): - texec.forward() + texec.forward(is_train=True) texec.outputs[0].copyto(out_cpu_array[islice]) for texec in train_execs: texec.backward() @@ -259,12 +273,17 @@ def _train_multi_device(symbol, ctx, input_shape, if kv: # push gradient, priority is negative index kv.push(index, grad_list, priority=-index) - # pull back the sum, to the same locations. - kv.pull(index, grad_list, priority=-index) - opt_list = opt_state_blocks[index] - # optimizea - for w, g, state in zip(arg_list, grad_list, opt_list): - optimizer.update(index, w, g, state) + if update_on_kvstore: + # pull back the weights + kv.pull(index, arg_list, priority=-index) + else: + # pull back the sum gradients, to the same locations. 
+ kv.pull(index, grad_list, priority=-index) + if not update_on_kvstore: + opt_list = opt_state_blocks[index] + # optimizea + for w, g, state in zip(arg_list, grad_list, opt_list): + optimizer.update(index, w, g, state) nbatch += 1 # epoch callback (for print purpose) if epoch_end_callback != None: @@ -566,7 +585,8 @@ def predict(self, X): return np.concatenate(outputs) def fit(self, X, y=None, eval_data=None, eval_metric='acc', - iter_end_callback=None, epoch_end_callback=None, logger=None): + iter_end_callback=None, epoch_end_callback=None, + update_on_kvstore=False, logger=None): """Fit the model. Parameters @@ -596,6 +616,9 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', A callback that is invoked at end of each batch For print purpose + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. + logger : logging logger, optional When not specified, default logger will be used. """ @@ -626,7 +649,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', eval_metric=eval_metric, iter_end_callback=iter_end_callback, epoch_end_callback=epoch_end_callback, - logger=logger) + update_on_kvstore=update_on_kvstore, logger=logger) def save(self, prefix, iteration=None): """Checkpoint the model checkpoint into file. @@ -688,7 +711,7 @@ def load(prefix, iteration, ctx=None): def create(symbol, X, y=None, ctx=None, num_round=None, optimizer='sgd', initializer=Xavier(), eval_data=None, eval_metric='acc', iter_end_callback=None, - logger=None, **kwargs): + update_on_kvstore=False, logger=None, **kwargs): """Functional style to create a model. This function will be more consistent with functional @@ -730,10 +753,14 @@ def create(symbol, X, y=None, ctx=None, A callback that is invoked at end of each iteration. This can be used to checkpoint model each iteration. + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. + logger : logging logger, optional """ model = FeedForward(symbol, ctx=ctx, num_round=num_round, optimizer=optimizer, initializer=initializer, **kwargs) model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, - iter_end_callback=iter_end_callback, logger=logger) + iter_end_callback=iter_end_callback, + update_on_kvstore=update_on_kvstore, logger=logger) return model diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 5418047ee27f..642d834cf7dc 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -1,5 +1,5 @@ # coding: utf-8 -"""NDArray interface of mxnet""" +"""NDArray API of mxnet.""" from __future__ import absolute_import import ctypes @@ -355,6 +355,8 @@ def empty(shape, ctx=None): out: Array The created NDArray. """ + if isinstance(shape, int): + shape = (shape, ) if ctx is None: ctx = Context.default_ctx return NDArray(handle=_new_alloc_handle(shape, ctx, False)) @@ -438,9 +440,10 @@ def load(fname): fname : str The name of the file.Can be S3 or HDFS address (remember built with S3 support). Example of fname: - - s3://my-bucket/path/my-s3-ndarray - - hdfs://my-bucket/path/my-hdfs-ndarray - - /path-to/my-local-ndarray + + - `s3://my-bucket/path/my-s3-ndarray` + - `hdfs://my-bucket/path/my-hdfs-ndarray` + - `/path-to/my-local-ndarray` Returns ------- @@ -479,9 +482,10 @@ def save(fname, data): fname : str The name of the file.Can be S3 or HDFS address (remember built with S3 support). 
         Example of fname:
-        - s3://my-bucket/path/my-s3-ndarray
-        - hdfs://my-bucket/path/my-hdfs-ndarray
-        - /path-to/my-local-ndarray
+
+        - `s3://my-bucket/path/my-s3-ndarray`
+        - `hdfs://my-bucket/path/my-hdfs-ndarray`
+        - `/path-to/my-local-ndarray`
 
     data : list of NDArray or dict of str to NDArray
         The data to be saved.
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index e8b8af78fe3b..a52dda2fce4d 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -1,10 +1,6 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments
-"""Symbolic support of mxnet.
-
-Symbolic API of MXNet
-
-"""
+# pylint: disable=invalid-name, protected-access, too-many-arguments
+"""Symbolic configuration API of mxnet."""
 from __future__ import absolute_import
 
 import ctypes
@@ -571,8 +567,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None):
                                          mx_uint(len(aux_states)),
                                          aux_args_handle,
                                          ctypes.byref(handle)))
-        executor = Executor(handle)
-
+        executor = Executor(handle, self)
         executor.arg_arrays = args
         executor.grad_arrays = args_grad
         executor.aux_arrays = aux_states
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 686f8cca3554..54f6c924ecdf 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -1,7 +1,6 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
-# pylint: disable=unused-argument, too-many-branches, too-many-statements
-# pylint: disable=unused-variable
+# pylint: disable=invalid-name, too-many-locals, fixme
+# pylint: disable=too-many-branches, too-many-statements
 """Visualization module"""
 from __future__ import absolute_import
 
diff --git a/scripts/travis_osx_install.sh b/scripts/travis_osx_install.sh
index f0a0be48a24b..04929633ee5a 100755
--- a/scripts/travis_osx_install.sh
+++ b/scripts/travis_osx_install.sh
@@ -21,12 +21,11 @@ conda update -q conda
 # Useful for debugging any issues with conda
 conda info -a
 
-if [ ${TASK} == "python-package3" ]; then
+if [ ${TASK} == "package3" ]; then
     conda create -n myenv python=3.4
-    alias python3=python
 else
     conda create -n myenv python=2.7
 fi
 source activate myenv
 conda install numpy scipy matplotlib nose
-python -m pip install graphviz
+python -m pip install graphviz
\ No newline at end of file
diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh
index dd2a62c8b37b..a56b73c469b9 100755
--- a/scripts/travis_script.sh
+++ b/scripts/travis_script.sh
@@ -19,12 +19,15 @@ fi
 
 # prereqs for things that need make
 cp make/config.mk config.mk
+export NOSE3=nosetests3
+export PYTHON3=python3
 if [ ${TRAVIS_OS_NAME} == "osx" ]; then
     source scripts/travis_osx_install.sh
     echo "USE_BLAS=apple" >> config.mk
     echo "USE_OPENMP=0" >> config.mk
-    alias nosetests='python -m noise'
-    alias nosetests3='python -m noise'
+    alias nosetests='python -m nose'
+    export NOSE3='python -m nose'
+    export PYTHON3=python
 else
     echo "USE_BLAS=blas" >> config.mk
     echo "USE_CUDNN=0" >> config.mk
@@ -46,6 +49,7 @@ fi
 if [ ${TASK} == "python" ]; then
     echo "USE_CUDA=0" >> config.mk
     make all || exit -1
+    python --version
     export MXNET_ENGINE_TYPE=ThreadedEngine
     nosetests tests/python/unittest || exit -1
     nosetests tests/python/train || exit -1
@@ -55,14 +59,16 @@ if [ ${TASK} == "python3" ]; then
     echo "USE_CUDA=0" >> config.mk
     make all || exit -1
     export MXNET_ENGINE_TYPE=ThreadedEngine
-    nosetests3 tests/python/unittest || exit -1
-    nosetests3 tests/python/train || exit -1
+    ${PYTHON3} --version
+    ${NOSE3} tests/python/unittest || exit -1
+    ${NOSE3} tests/python/train || exit -1
 fi
 
 if [ ${TASK} == "python_naive" ]; then
     echo "USE_CUDA=0" >> config.mk
     make all || exit -1
     export MXNET_ENGINE_TYPE=NaiveEngine
+    python --version
     nosetests tests/python/unittest || exit -1
     nosetests tests/python/train || exit -1
 fi
@@ -71,6 +77,7 @@ if [ ${TASK} == "python_perdev" ]; then
     echo "USE_CUDA=0" >> config.mk
     make all || exit -1
     export MXNET_ENGINE_TYPE=ThreadedEnginePerDevice
+    python --version
     nosetests tests/python/unittest || exit -1
     nosetests tests/python/train || exit -1
 fi
diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h
index a9296afff3be..cd50c5e10b08 100644
--- a/src/io/image_augmenter.h
+++ b/src/io/image_augmenter.h
@@ -90,8 +90,7 @@ struct ImageAugmentParam : public dmlc::Parameter {
 class ImageAugmenter {
  public:
   // contructor
-  ImageAugmenter(void)
-      : tmpres_(false) {
+  ImageAugmenter(void) {
 #if MXNET_USE_OPENCV
     rotateM_ = cv::Mat(2, 3, CV_32F);
 #endif
@@ -211,20 +210,12 @@ class ImageAugmenter {
 #endif
 
  private:
-  // temp input space
-  mshadow::TensorContainer tmpres_;
-  // mean image
-  mshadow::TensorContainer meanimg_;
-  /*! \brief temp space */
-  mshadow::TensorContainer img_;
 #if MXNET_USE_OPENCV
   // temporal space
   cv::Mat temp_;
   // rotation param
   cv::Mat rotateM_;
-  // whether the mean file is ready
 #endif
-  bool meanfile_ready_;
   // parameters
   ImageAugmentParam param_;
   /*! \brief list of possible rotate angle */
diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h
index 00f46ec4f721..1ec28d13d65a 100644
--- a/src/io/iter_batchloader.h
+++ b/src/io/iter_batchloader.h
@@ -56,8 +56,7 @@ class BatchLoader : public IIterator {
     std::vector > kwargs_left;
     // init batch param, it could have similar param with
     kwargs_left = param_.InitAllowUnknown(kwargs);
-    // init base iterator
-    base_->Init(kwargs);
+    // init object attributes
     std::vector data_shape_vec;
     data_shape_vec.push_back(param_.batch_size);
     for (size_t shape_dim = 0; shape_dim < param_.data_shape.ndim(); ++shape_dim) {
@@ -75,6 +74,8 @@ class BatchLoader : public IIterator {
     label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f);
     out_.data.push_back(TBlob(data_holder_));
     out_.data.push_back(TBlob(label_holder_));
+    // init base iterator
+    base_->Init(kwargs);
   }
   inline void BeforeFirst(void) {
     if (param_.round_batch == 0 || num_overflow_ == 0) {
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
index 92cbc55951a3..7d3c0dcb7802 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -6,10 +6,11 @@
 #ifndef MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_
 #define MXNET_STORAGE_POOLED_STORAGE_MANAGER_H_
 
+#include <mutex>
 #include <unordered_map>
 #include <vector>
-#include "storage_manager.h"
-#include "mxnet/base.h"
+#include <mxnet/base.h>
+#include "./storage_manager.h"
 
 namespace mxnet {
 namespace storage {
@@ -35,13 +36,18 @@ class PooledStorageManager final : public StorageManager {
 
  private:
  void ReleaseAll();
+  // internal mutex
+  std::mutex mutex_;
+  // used memory
  size_t used_memory_ = 0;
+  // memory pool
  std::unordered_map<size_t, std::vector<void*>> memory_pool_;
  DISALLOW_COPY_AND_ASSIGN(PooledStorageManager);
};  // class PooledStorageManager

 template
 void* PooledStorageManager::Alloc(size_t size) {
+  std::lock_guard<std::mutex> lock(mutex_);
  auto&& reuse_it = memory_pool_.find(size);
  if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) {
    if (kThreshold <= used_memory_) {
@@ -60,6 +66,7 @@ void* PooledStorageManager::Alloc(size_t size) {
 
 template
 void PooledStorageManager::Free(void* ptr,
                                 size_t size) {
+  std::lock_guard<std::mutex> lock(mutex_);
  auto&& reuse_pool = memory_pool_[size];
  reuse_pool.push_back(ptr);
 }
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index 08af99621b40..4e9c85b71f74 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -5,21 +5,25 @@
 #include
 #include
 #include
-#include
-#include
-#include "storage_manager.h"
-#include "naive_storage_manager.h"
-#include "pooled_storage_manager.h"
-#include "cpu_device_storage.h"
-#include "gpu_device_storage.h"
-#include "pinned_memory_storage.h"
+#include "./storage_manager.h"
+#include "./naive_storage_manager.h"
+#include "./pooled_storage_manager.h"
+#include "./cpu_device_storage.h"
+#include "./gpu_device_storage.h"
+#include "./pinned_memory_storage.h"
 #include "../common/cuda_utils.h"
-#include "../common/utils.h"
+#include "../common/lazy_alloc_array.h"
 
 namespace mxnet {
 
 // consider change storage as a pure abstract class
-struct Storage::Impl {
+class StorageImpl : public Storage {
+ public:
+  Handle Alloc(size_t size, Context ctx) override;
+  void Free(Handle handle) override;
+  virtual ~StorageImpl() = default;
+
+ private:
  static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul;
  static constexpr size_t kMaxNumberOfDevices = Context::kMaxDevType + 1;
  static constexpr size_t kMaxNumberOfDeviceIDs = Context::kMaxDevID + 1;
@@ -43,64 +47,56 @@
      LOG(FATAL) << "Unimplemented device";
    }
  }
-
-  std::array,
-             kMaxNumberOfDeviceIDs>,
-             kMaxNumberOfDevices> storage_managers;
-  std::mutex m;
+  // internal storage managers
+  std::array<common::LazyAllocArray<storage::StorageManager>,
+             kMaxNumberOfDevices> storage_managers_;
 };  // struct Storage::Impl
 
-Storage::Handle Storage::Alloc(size_t size, Context ctx) {
+Storage::Handle StorageImpl::Alloc(size_t size, Context ctx) {
  // space already recycled, ignore request
  Handle hd;
  hd.ctx = ctx;
  hd.size = size;
-  {
-    std::lock_guard lock{impl_->m};
-    auto&& device = impl_->storage_managers.at(ctx.dev_type);
-    auto&& device_id_it = device.at(ctx.dev_id);
-    // Allocate device if necessary.
-    if (!device_id_it) {
-      switch (ctx.dev_type) {
-        case Context::kCPU: {
-          device_id_it = common::MakeUnique<
-              Storage::Impl::CurrentStorageManager<
-                  storage::CPUDeviceStorage>>();
-          break;
-        }
-        case Context::kCPUPinned: {
-          device_id_it = common::MakeUnique<
-              Storage::Impl::CurrentStorageManager<
-                  storage::PinnedMemoryStorage>>();
-          break;
+  auto&& device = storage_managers_.at(ctx.dev_type);
+  storage::StorageManager *manager = device.Get(
+      ctx.dev_id, [ctx]() {
+        storage::StorageManager *ptr = nullptr;
+        switch (ctx.dev_type) {
+          case Context::kCPU: {
+            ptr = new CurrentStorageManager<storage::CPUDeviceStorage>();
+            break;
+          }
+          case Context::kCPUPinned: {
+            ptr = new CurrentStorageManager<storage::PinnedMemoryStorage>();
+            break;
+          }
+          case Context::kGPU: {
+            ptr = new CurrentStorageManager<storage::GPUDeviceStorage>();
+            break;
+          }
+          default: LOG(FATAL) << "Unimplemented device";
+        }
-        case Context::kGPU: {
-          device_id_it = common::MakeUnique>();
-          break;
-        }
-        default:
-          LOG(FATAL) << "Unimplemented device";
-      }
-    }
-    Impl::ActivateDevice(ctx);
-    hd.dptr = device_id_it->Alloc(size);
-  }
+        return ptr;
+      });
+  this->ActivateDevice(ctx);
+  hd.dptr = manager->Alloc(size);
  return hd;
 }
 
-void Storage::Free(Storage::Handle handle) {
-  std::lock_guard lock{impl_->m};
-  Impl::ActivateDevice(handle.ctx);
-  impl_->storage_managers.at(handle.ctx.dev_type)
-      .at(handle.ctx.dev_id)
-      ->Free(handle.dptr, handle.size);
+void StorageImpl::Free(Storage::Handle handle) {
+  const Context &ctx = handle.ctx;
+  auto&& device = storage_managers_.at(ctx.dev_type);
+  storage::StorageManager *manager = device.Get(
+      ctx.dev_id, []() {
+        LOG(FATAL) << "Cannot Free space to a device you have not allocated";
+        return nullptr;
+      });
+  this->ActivateDevice(ctx);
+  manager->Free(handle.dptr, handle.size);
 }
 
-Storage::~Storage() = default;
-
 std::shared_ptr<Storage> Storage::_GetSharedRef() {
-  static std::shared_ptr<Storage> inst(new Storage());
+  static std::shared_ptr<Storage> inst(new StorageImpl());
  return inst;
 }
 
@@ -108,8 +104,4 @@ Storage* Storage::Get() {
  static Storage *ptr = _GetSharedRef().get();
  return ptr;
 }
-
-Storage::Storage() : impl_{new Impl{}} {}
-
-
 }  // namespace mxnet
diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py
index 14e8f6c700a8..fdf89142fc11 100644
--- a/tests/python/train/test_mlp.py
+++ b/tests/python/train/test_mlp.py
@@ -53,7 +53,8 @@ def test_mlp():
                                  ctx=[mx.cpu(i) for i in range(2)],
                                  num_round=num_round,
                                  learning_rate=0.01, wd=0.0004,
-                                 momentum=0.9)
+                                 momentum=0.9,
+                                 update_on_kvstore=True)
     logging.info('Finish traning...')
 
     prob = model.predict(val_dataiter)
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index ed9ce358f24a..5ece7bac8023 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -85,6 +85,6 @@ def test_NDArrayIter():
         assert(labelcount[i] == 100)
 
 if __name__ == "__main__":
-    test_NumpyIter()
+    #test_NDArrayIter()
     #test_MNISTIter()
-    #test_Cifar10Rec()
+    test_Cifar10Rec()
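
Usage note (not part of the patch above): a minimal sketch of how the new `update_on_kvstore` flag could be exercised from the Python API, mirroring the `FeedForward.fit` signature and `tests/python/train/test_mlp.py` in this diff. `net` and `train_iter` are hypothetical placeholders for a symbol and a data iterator you already have.

```python
import mxnet as mx

# Assumed to exist: `net` (an mxnet Symbol) and `train_iter` (a data iterator).
model = mx.model.FeedForward(net,
                             ctx=[mx.cpu(i) for i in range(2)],
                             num_round=10,
                             learning_rate=0.01, momentum=0.9, wd=0.0004)

# update_on_kvstore=True runs the optimizer update on the kvstore (CPU side)
# instead of on each training device; the docstring added in this patch
# suggests this for nets with many small weights.
model.fit(X=train_iter, update_on_kvstore=True)
```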