From 3e8e1e063f33cbd5cbf844d630fe7f1648bd7957 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 8 Sep 2015 16:36:18 -0600 Subject: [PATCH 01/13] simple bind --- example/cifar10/cifar10.py | 59 ++++++++++++++-------------------- python/mxnet/__init__.py | 2 +- python/mxnet/context.py | 30 +++++++++++++++++ python/mxnet/executor.py | 28 ++++++++++++++++ python/mxnet/narray.py | 10 ++---- python/mxnet/symbol.py | 42 +++++++++++++++++++++++- src/operator/batch_norm-inl.h | 4 +++ src/operator/convolution-inl.h | 34 +++++++++++++------- tests/python/test_conv.py | 4 +-- 9 files changed, 155 insertions(+), 58 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 10a7c40eea03..278c10e3f5c5 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -5,7 +5,7 @@ import sys sys.path.append("../../tests/python") import get_data - +import time """ CXXNET Result: @@ -70,8 +70,8 @@ def ConvFactory(**kwargs): param = copy.copy(kwargs) act = param["act_type"] del param["act_type"] + param["workspace"] = 512 param["name"] = "conv%d" % conv_cnt - param["nstep"] = 64 conv = mx.symbol.Convolution(**param) bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt) relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act) @@ -89,13 +89,11 @@ def DownsampleFactory(data, ch_3x3, stride = 2): param["num_filter"] = ch_3x3 param["act_type"] = "relu" param["data"] = data - param["nstep"] = 100 param["pad"] = (1, 1) conv3x3 = ConvFactory(**param) # pool del param["num_filter"] del param["act_type"] - del param["nstep"] del param["pad"] param["pool_type"] = "max" param["name"] = "pool%d" % pool_cnt @@ -117,7 +115,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3): param["stride"] = (1, 1) param["act_type"] = "relu" param["data"] = data - param["nstep"] = 128 conv1x1 = ConvFactory(**param) # 3x3 @@ -143,7 +140,7 @@ def RandomInit(narray): in3a = SimpleFactory(conv1, 32, 32) in3b = SimpleFactory(in3a, 32, 48) in3c = DownsampleFactory(in3b, 80) -in4a = SimpleFactory(in3c, 112, 38) +in4a = SimpleFactory(in3c, 112, 48) in4b = SimpleFactory(in4a, 96, 64) in4c = SimpleFactory(in4b, 80, 80) in4d = SimpleFactory(in4c, 48, 96) @@ -155,27 +152,30 @@ def RandomInit(narray): fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") loss = mx.symbol.Softmax(data=fc, name="sm") -args_list = loss.list_arguments() +epoch = 9 +lr = 0.05 +wd = 0.0001 +momentum = 0.9 batch_size = 128 data_shape = (batch_size, 3, 28, 28) -arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape) -arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes] +in_data = mx.narray.empty(data_shape, mx.gpu()) +executor = loss.simple_bind(mx.gpu(), {"data": in_data}) +out_narray = executor.heads()[0] +pred = mx.narray.zeros(out_narray.shape) -inputs = dict(zip(args_list, arg_narrays)) +arg_narrays, grad_narrays = executor.list_arguments() +momentum_narrays = [mx.narray.zeros(item.shape, mx.gpu()) for item in grad_narrays] -name2shape = dict(zip(args_list, arg_shapes)) -pred = mx.narray.zeros(out_shapes[0]) +inputs = dict(zip(loss.list_arguments(), arg_narrays)) +block = zip(grad_narrays, arg_narrays, momentum_narrays) np.random.seed(0) # set random weight -for name, narray in 
inputs.items(): +for name, narray in zip(loss.list_arguments(), arg_narrays): if "weight" in name: narray[:] = np.random.uniform(-0.1, 0.1, narray.shape) if "bias" in name: @@ -185,25 +185,11 @@ def RandomInit(narray): if "beta" in name: narray[:] = 0.0 -# bind executer -# TODO(bing): think of a better bind interface -executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays) -# update - -out_narray = executor.heads()[0] - -epoch = 9 -lr = 0.05 -wd = 0.0001 -momentum = 0.9 - def Update(grad, weight, mom): mom[:] *= momentum mom[:] += -lr * (grad / batch_size + wd * weight) weight[:] += mom -block = list(zip(grad_narrays, arg_narrays, mom_narrays)) - #check data get_data.GetCifar10() @@ -224,15 +210,17 @@ def Update(grad, weight, mom): batch_size=batch_size, nthread=1) -tmp_label = mx.narray.zeros(name2shape["sm_label"]) +tmp_label = mx.narray.zeros(inputs["sm_label"].shape) -def progress(count, total, suffix=''): - bar_len = 80 +def progress(count, total, epoch, toc): + bar_len = 60 filled_len = int(round(bar_len * count / float(total))) percents = round(100.0 * count / float(total), 1) bar = '=' * filled_len + '-' * (bar_len - filled_len) - + tic = time.time() + speed = batch_size / float(tic - toc) + suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) def test_cifar(): @@ -247,7 +235,7 @@ def test_cifar(): val_nbatch = 0 all_train_bacth = 50000 / float(batch_size) for data, label in train_dataiter: - progress(train_nbatch, all_train_bacth, "Epoch %d" % i) + toc = time.time() label = label.asnumpy().flatten() tmp_label[:] = label inputs["data"][:] = data @@ -260,6 +248,7 @@ def test_cifar(): for grad, weight, mom in block: Update(grad, weight, mom) + progress(train_nbatch, all_train_bacth, i, toc) # evaluate for data, label in test_dataiter: diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index a8632bfa2ff8..4a0c62bccab3 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -8,7 +8,7 @@ """ from __future__ import absolute_import -from .context import Context, current_context +from .context import Context, current_context, cpu, gpu from .base import MXNetError from . import narray from . import symbol diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 7a043bf5c4b9..9d84f6915fbb 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -52,6 +52,36 @@ def __exit__(self, ptype, value, trace): # initialize the default context in Context Context.default_ctx = Context('cpu', 0) +def cpu(device_id=0): + """ + Return CPU context + + Parameters + ---------- + device_id : int (default=0) + the device id of the device, needed for GPU + + Returns + --------- + A cpu context + """ + return Context('cpu', device_id) + +def gpu(device_id=0): + """ + Return CPU context + + Parameters + ---------- + device_id : int (default=0) + the device id of the device, needed for GPU + + Returns + --------- + A cpu context + """ + return Context('gpu', device_id) + def current_context(): """Return the current context. 
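The cpu() and gpu() helpers added above are shortcuts for Context('cpu', device_id) and Context('gpu', device_id). A minimal usage sketch of how this patch combines them with the new simple_bind call; the toy symbol and shapes below are illustrative stand-ins for the network built in example/cifar10/cifar10.py, not part of the patch itself:

    import mxnet as mx

    # a toy network, mirroring the pattern in example/cifar10/cifar10.py
    data = mx.symbol.Variable('data')
    fc = mx.symbol.FullyConnected(data=data, num_hidden=10, name='fc1')
    loss = mx.symbol.Softmax(data=fc, name='sm')

    in_data = mx.narray.empty((128, 3, 28, 28), mx.gpu())     # uninitialized input buffer on GPU 0
    executor = loss.simple_bind(mx.gpu(), {"data": in_data})  # remaining args/grads allocated on the GPU
    arg_narrays, grad_narrays = executor.list_arguments()     # same order as loss.list_arguments()
    executor.forward()
    pred = mx.narray.zeros(executor.heads()[0].shape, mx.cpu())
    pred[:] = executor.heads()[0]                              # copy the output back for evaluation

simple_bind fills in zero-initialized NArrays for every argument and gradient the caller did not supply, so the example no longer has to build the argument, gradient, and auxiliary lists by hand before calling bind.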
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 17df30190ce8..312ae8edb6c6 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -22,6 +22,34 @@ def __init__(self, handle): if not isinstance(handle, ExecutorHandle): raise TypeError("Handle type error") self.handle = handle + self.arg_narrays = [] + self.grad_narrays = [] + self.auxiliary_states = [] + + def list_arguments(self, with_grad=True): + """Return arguments (and grad for arguments) + + Parameters + ---------- + with_grad: bool + whether return args with grad + + Returns + ------- + if with_grad = True, return (args, grad) pair list + otherwise return args list only + Note: args sequence is same to symbol.list_arguments() + """ + if with_grad: + return self.arg_narrays, self.grad_narrays + else: + return self.arg_narrays + + def list_auxiliary_states(): + """Return auxiliary states of executor + Note: auxiliary states is same to symbol.list_auxiliary_states() + """ + return self.auxiliary_states def forward(self, is_train=True): """Do forward. diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py index acc05d08d546..208fd8e17d7a 100644 --- a/python/mxnet/narray.py +++ b/python/mxnet/narray.py @@ -349,9 +349,7 @@ def zeros(shape, ctx=None): out: Array The created NArray. """ - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 0.0 return arr @@ -371,15 +369,11 @@ def ones(shape, ctx=None): out: Array The created NArray. """ - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 1.0 return arr - - def array(source_array, ctx=None): """Create a new NArray that copies content from source_array. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index f882933538b2..465831ed5a88 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -10,7 +10,7 @@ from .base import NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context -from .narray import NArray +from .narray import NArray, zeros from .executor import Executor @@ -332,6 +332,46 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): raise TypeError('Only Accept list of NArrays or dict of str->NArray') return c_array(NArrayHandle, arg_handles) + def simple_bind(self, ctx, args, grad_req='write'): + """Simply bind current symbol to get an executor + Parameters + ---------- + ctx : Context + The device context the generated executor to run on. + + args : list of NArray or dict of str->NArray + Input arguments to the symbol. + - type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray, + - Not all the arguments must be provided. 
+ Returns + ------- + executor : mxnet.Executor + The generated Executor + """ + if not isinstance(args, dict): + raise TypeError("args must be dict of str->NArray") + input_shapes = dict((arr[0], arr[1].shape) for arr in args.items()) + arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) + if arg_shapes == None: + raise ValueError("Input node is not complete") + # alloc space + arg_narrays = [] + for name, shape in zip(self.list_arguments(), arg_shapes): + if name in args: + arg_narrays.append(args[name]) + else: + arg_narrays.append(zeros(shape, ctx)) + # TODO(bing): specail treat input data grad + grad_narrays = [zeros(shape, ctx) for shape in arg_shapes] + aux_narrays = [zeros(shape, ctx) for shape in aux_shapes] + executor = self.bind(ctx, arg_narrays, grad_narrays, grad_req, aux_narrays) + executor.arg_narrays = arg_narrays + executor.grad_narrays = grad_narrays + executor.auxiliary_states = aux_narrays + + return executor + def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): """Bind current symbol to get an executor. diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 0f3b303f85b6..1409615e853a 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -261,6 +261,10 @@ class BatchNormProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; + std::vector BackwardResource() const override { + return {Resource::kTempSpace}; + } + private: BatchNormParam param_; }; // class BatchNormProp diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index b313aab14f94..595ee375f68a 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -30,7 +30,7 @@ struct ConvolutionParam : public dmlc::Parameter { TShape pad; uint32_t num_filter; uint32_t num_group; - uint32_t nstep; + uint32_t workspace; bool no_bias; DMLC_DECLARE_PARAMETER(ConvolutionParam) { int shape[] = {1, 1}; @@ -44,8 +44,8 @@ struct ConvolutionParam : public dmlc::Parameter { .describe("convolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) .describe("number of groups partition"); - DMLC_DECLARE_FIELD(nstep).set_default(2).set_range(1, 10000) - .describe("process n images once"); + DMLC_DECLARE_FIELD(workspace).set_default(128).set_range(1, 10000) + .describe("Tmp workspace for convolution (MB)"); DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); } @@ -80,8 +80,8 @@ class ConvolutionOp : public Operator { Tensor out = out_data[kOut].get(s); this->InitTemp(ctx, data.shape_, out.shape_); const index_t nbatch = data.size(0); - for (index_t i = 0; i < nbatch; i += param_.nstep) { - const index_t step = std::min(param_.nstep, nbatch - i); + for (index_t i = 0; i < nbatch; i += nstep_) { + const index_t step = std::min(nstep_, nbatch - i); temp_col_.Resize(mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * step)); temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0], @@ -148,8 +148,8 @@ class ConvolutionOp : public Operator { Tensor gwmat = in_grad[kWeight].get_with_shape(wmat_shape, s); this->InitTemp(ctx, data.shape_, grad.shape_); const index_t nbatch = data.size(0); - for (index_t i = 0; i < nbatch; i += param_.nstep) { - const index_t step = std::min(param_.nstep, nbatch - i); + for (index_t i = 0; i < nbatch; i += nstep_) { + const index_t step = std::min(nstep_, nbatch - i); temp_col_.Resize(Shape2(shape_colunit_[0], shape_colunit_[1] * step)); temp_dst_.Resize(Shape3(shape_dstunit_[0], @@ 
-220,16 +220,19 @@ class ConvolutionOp : public Operator { shape_dstunit_ = mshadow::Shape3(param_.num_group, param_.num_filter / param_.num_group, oshape[2] * oshape[3]); - int nop = (ishape[0] + param_.nstep - 1) / param_.nstep; - param_.nstep = (ishape[0] + nop - 1) / nop; + const uint32_t workspace_size = param_.workspace << 18; + nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), + ishape[0]), 1U); + int nop = (ishape[0] + nstep_ - 1) / nstep_; + nstep_ = (ishape[0] + nop - 1) / nop; mshadow::Stream *s = ctx.get_stream(); temp_col_.set_stream(s); temp_dst_.set_stream(s); temp_col_.Resize(mshadow::Shape2(shape_colunit_[0], - shape_colunit_[1] * param_.nstep)); + shape_colunit_[1] * nstep_)); temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0], shape_dstunit_[1], - shape_dstunit_[2] * param_.nstep)); + shape_dstunit_[2] * nstep_)); } ConvolutionParam param_; @@ -238,6 +241,7 @@ class ConvolutionOp : public Operator { mshadow::TensorContainer temp_dst_; mshadow::Shape<2> shape_colunit_; mshadow::Shape<3> shape_dstunit_; + index_t nstep_; }; // class ConvolutionOp template @@ -328,6 +332,14 @@ class ConvolutionProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; + std::vector ForwardResource() const override { + return {Resource::kTempSpace}; + } + + std::vector BackwardResource() const override { + return {Resource::kTempSpace}; + } + private: ConvolutionParam param_; }; // class ConvolutionProp diff --git a/tests/python/test_conv.py b/tests/python/test_conv.py index 9ab34ce1c8ae..d63a0542ce7a 100644 --- a/tests/python/test_conv.py +++ b/tests/python/test_conv.py @@ -12,12 +12,12 @@ def CalAcc(out, label): # symbol net batch_size = 100 data = mx.symbol.Variable('data') -conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100) +conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2)) bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1") act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu") mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max') -conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100) +conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2)) bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2") act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu") mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max') From b0e515acedfe07fcd6a791cd34ca15644c5046e9 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 9 Sep 2015 20:16:18 -0600 Subject: [PATCH 02/13] cudnn conv --- Makefile | 2 +- example/cifar10/cifar10.py | 11 +- include/mxnet/base.h | 2 +- mshadow | 2 +- python/mxnet/symbol.py | 2 +- src/dag_engine/naive_engine.cc | 7 +- src/operator/convolution-inl.h | 14 +- src/operator/convolution.cu | 6 + src/operator/cudnn_convolution-inl.h | 272 +++++++++++++++++++++++++++ src/operator/fully_connected-inl.h | 8 + 10 files changed, 314 insertions(+), 12 deletions(-) create mode 100644 src/operator/cudnn_convolution-inl.h diff --git a/Makefile b/Makefile index 1bbfc12655a5..ea72c4dc3a1d 100644 --- a/Makefile +++ b/Makefile @@ -66,7 +66,7 @@ ifeq ($(USE_OPENMP_ITER), 1) endif ifeq ($(USE_CUDNN), 1) - CFLAGS += -DCXXNET_USE_CUDNN=1 + CFLAGS += -DMSHADOW_USE_CUDNN=1 LDFLAGS += -lcudnn endif diff --git a/example/cifar10/cifar10.py 
b/example/cifar10/cifar10.py index 278c10e3f5c5..a937060f6520 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -70,7 +70,7 @@ def ConvFactory(**kwargs): param = copy.copy(kwargs) act = param["act_type"] del param["act_type"] - param["workspace"] = 512 + param["workspace"] = 256 param["name"] = "conv%d" % conv_cnt conv = mx.symbol.Convolution(**param) bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt) @@ -175,7 +175,7 @@ def RandomInit(narray): np.random.seed(0) # set random weight -for name, narray in zip(loss.list_arguments(), arg_narrays): +for name, narray in inputs.items(): if "weight" in name: narray[:] = np.random.uniform(-0.1, 0.1, narray.shape) if "bias" in name: @@ -213,7 +213,7 @@ def Update(grad, weight, mom): tmp_label = mx.narray.zeros(inputs["sm_label"].shape) def progress(count, total, epoch, toc): - bar_len = 60 + bar_len = 50 filled_len = int(round(bar_len * count / float(total))) percents = round(100.0 * count / float(total), 1) @@ -223,6 +223,7 @@ def progress(count, total, epoch, toc): suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) + def test_cifar(): acc_train = 0. acc_val = 0. @@ -233,8 +234,10 @@ def test_cifar(): val_acc = 0.0 train_nbatch = 0 val_nbatch = 0 - all_train_bacth = 50000 / float(batch_size) + all_train_bacth = round(50000 / float(batch_size) + 1) for data, label in train_dataiter: + if train_nbatch > 30: + break toc = time.time() label = label.asnumpy().flatten() tmp_label[:] = label diff --git a/include/mxnet/base.h b/include/mxnet/base.h index a7a3a8063a92..e3fbe002fdfc 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -30,7 +30,7 @@ *\brief whether to use cudnn library for convolution */ #ifndef MXNET_USE_CUDNN -#define MXNET_USE_CUDNN 0 +#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN #endif /*! 
\brief namespace of mxnet */ diff --git a/mshadow b/mshadow index 3053f8cdfea0..208a198213ea 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 3053f8cdfea0274739282ced015ad458090760e8 +Subproject commit 208a198213ea011e42f91b128b14a7206cce62a5 diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 465831ed5a88..e64b0d8e8253 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -351,7 +351,7 @@ def simple_bind(self, ctx, args, grad_req='write'): """ if not isinstance(args, dict): raise TypeError("args must be dict of str->NArray") - input_shapes = dict((arr[0], arr[1].shape) for arr in args.items()) + input_shapes = dict((name, arr.shape) for name, arr in args.items()) arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) if arg_shapes == None: raise ValueError("Input node is not complete") diff --git a/src/dag_engine/naive_engine.cc b/src/dag_engine/naive_engine.cc index bffeb474bfa6..1cf1c07c5a62 100644 --- a/src/dag_engine/naive_engine.cc +++ b/src/dag_engine/naive_engine.cc @@ -9,9 +9,14 @@ class NaiveEngine : public DAGEngine { public: NaiveEngine() { #if MXNET_USE_CUDA + #if MXNET_USE_CUDNN + LOG(INFO) << "MXNET USE CUDNN"; + stream_ = mshadow::NewStream(true, true); + #else stream_ = mshadow::NewStream(true, false); + #endif // MXNET_USE_CUDNN ctx_.stream = stream_; - #endif + #endif // MXNET_USE_CUDA } ~NaiveEngine() { diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index 595ee375f68a..f8a29b204d60 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -30,7 +30,7 @@ struct ConvolutionParam : public dmlc::Parameter { TShape pad; uint32_t num_filter; uint32_t num_group; - uint32_t workspace; + uint64_t workspace; bool no_bias; DMLC_DECLARE_PARAMETER(ConvolutionParam) { int shape[] = {1, 1}; @@ -78,6 +78,10 @@ class ConvolutionOp : public Operator { TShape wmat_shape(ws, ws + 3); Tensor wmat = in_data[kWeight].get_with_shape(wmat_shape, s); Tensor out = out_data[kOut].get(s); + #if defined(__CUDACC__) + CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + #endif this->InitTemp(ctx, data.shape_, out.shape_); const index_t nbatch = data.size(0); for (index_t i = 0; i < nbatch; i += nstep_) { @@ -146,6 +150,10 @@ class ConvolutionOp : public Operator { Tensor grad = out_grad[kOut].get(s); Tensor gdata = in_grad[kData].get(s); Tensor gwmat = in_grad[kWeight].get_with_shape(wmat_shape, s); + #if defined(__CUDACC__) + CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + #endif this->InitTemp(ctx, data.shape_, grad.shape_); const index_t nbatch = data.size(0); for (index_t i = 0; i < nbatch; i += nstep_) { @@ -220,8 +228,8 @@ class ConvolutionOp : public Operator { shape_dstunit_ = mshadow::Shape3(param_.num_group, param_.num_filter / param_.num_group, oshape[2] * oshape[3]); - const uint32_t workspace_size = param_.workspace << 18; - nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), + const uint64_t workspace_size = param_.workspace << 20; + nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), ishape[0]), 1U); int nop = (ishape[0] + nstep_ - 1) / nstep_; nstep_ = (ishape[0] + nop - 1) / nop; diff --git a/src/operator/convolution.cu b/src/operator/convolution.cu index 4f0a3ce78b45..8c7a5ebfe5be 100644 --- a/src/operator/convolution.cu +++ b/src/operator/convolution.cu @@ -6,12 +6,18 @@ */ #include "./convolution-inl.h" +#include 
"./cudnn_convolution-inl.h" + namespace mxnet { namespace op { template<> Operator* CreateOp(ConvolutionParam param) { + #if MXNET_USE_CUDNN == 1 + return new CuDNNConvolutionOp(param); + #else return new ConvolutionOp(param); + #endif // MXNET_USE_CUDNN } } // namespace op diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h new file mode 100644 index 000000000000..837cf4c1fbcb --- /dev/null +++ b/src/operator/cudnn_convolution-inl.h @@ -0,0 +1,272 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cudnn_convolution-inl.h + * \brief + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_ +#define MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_ +#include "./convolution-inl.h" + +namespace mxnet { +namespace op { +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 +class CuDNNConvolutionOp : public Operator { + public: + explicit CuDNNConvolutionOp(ConvolutionParam param) { + this->param_ = param; + init_cudnn_ = false; + // TODO(xxx): fp16 + dtype_ = CUDNN_DATA_FLOAT; + } + + ~CuDNNConvolutionOp() { + if (init_cudnn_) { + CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnDestroyTensorDescriptor(bias_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnDestroyFilterDescriptor(filter_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnDestroyConvolutionDescriptor(conv_desc_), CUDNN_STATUS_SUCCESS); + } + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + size_t expected = param_.no_bias ? 2 : 3; + float alpha = 1.0f; + float beta = 0.0f; + CHECK_EQ(in_data.size(), expected); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[kData].get(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor out = out_data[kOut].get(s); + CHECK_EQ(data.CheckContiguous(), true); + CHECK_EQ(wmat.CheckContiguous(), true); + CHECK_EQ(out.CheckContiguous(), true); + if (!init_cudnn_) { + Init(s, in_data, out_data); + } + CHECK_EQ(cudnnConvolutionForward(s->dnn_handle_, + &alpha, + in_desc_, + data.dptr_, + filter_desc_, + wmat.dptr_, + conv_desc_, + algo_, + temp_.dptr_, + param_.workspace, + &beta, + out_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + if (!param_.no_bias) { + beta = 1.0f; + Tensor bias = in_data[kBias].get(s); + CHECK_EQ(cudnnAddTensor(s->dnn_handle_, + CUDNN_ADD_SAME_C, + &alpha, + bias_desc_, + bias.dptr_, + &beta, + out_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + float alpha = 1.0f; + float beta = 0.0f; + size_t expected = param_.no_bias == 0 ? 
3 : 2; + CHECK_EQ(out_grad.size(), 1); + CHECK(in_data.size() == expected && in_grad.size() == expected); + // TODO(bing): think about how to support add to + CHECK_EQ(req[kWeight], kWriteTo); + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[kOut].get(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor gwmat = in_grad[kWeight].get(s); + Tensor data = in_data[kData].get(s); + Tensor gdata = in_grad[kData].get(s); + if (!param_.no_bias) { + Tensor gbias = in_grad[kBias].get(s); + CHECK_EQ(cudnnConvolutionBackwardBias(s->dnn_handle_, + &alpha, + out_desc_, + grad.dptr_, + &beta, + bias_desc_, + gbias.dptr_), CUDNN_STATUS_SUCCESS); + } + CHECK_EQ(cudnnConvolutionBackwardFilter_v3(s->dnn_handle_, + &alpha, + in_desc_, + data.dptr_, + out_desc_, + grad.dptr_, + conv_desc_, + back_algo_w_, + temp_.dptr_, + param_.workspace, + &beta, + filter_desc_, + gwmat.dptr_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnConvolutionBackwardData_v3(s->dnn_handle_, + &alpha, + filter_desc_, + wmat.dptr_, + out_desc_, + grad.dptr_, + conv_desc_, + back_algo_, + temp_.dptr_, + param_.workspace, + &beta, + in_desc_, + gdata.dptr_), CUDNN_STATUS_SUCCESS); + } + + private: + inline void Init(mshadow::Stream *s, + const std::vector &in_data, + const std::vector &out_data) { + using namespace mshadow; + size_t expected = param_.no_bias ? 2 : 3; + CHECK_EQ(in_data.size(), expected); + CHECK_EQ(out_data.size(), 1); + if (!init_cudnn_) { + init_cudnn_ = true; + temp_.set_stream(s); + size_t workspace = static_cast(param_.workspace); + size_t back_size = 0; + size_t back_size_w = 0; + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&bias_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateFilterDescriptor(&filter_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateConvolutionDescriptor(&conv_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, + dtype_, + param_.num_filter, + data.shape_[1], + param_.kernel[0], + param_.kernel[1]), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetConvolution2dDescriptor(conv_desc_, + param_.pad[0], + param_.pad[1], + param_.stride[0], + param_.stride[1], + 1, + 1, + CUDNN_CROSS_CORRELATION), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3]), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + out.shape_[0], + out.shape_[1], + out.shape_[2], + out.shape_[3]), CUDNN_STATUS_SUCCESS); + if (!param_.no_bias) { + Tensor bias = in_data[kBias].get(s); + CHECK_EQ(cudnnSetTensor4dDescriptor(bias_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + 1, + bias.shape_[0], + 1, + 1), CUDNN_STATUS_SUCCESS); + } + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + CHECK_EQ(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, + in_desc_, + filter_desc_, + conv_desc_, + out_desc_, + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + param_.workspace, + &algo_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + in_desc_, + out_desc_, + conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + param_.workspace, + &back_algo_w_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + 
out_desc_, + conv_desc_, + in_desc_, + CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + param_.workspace, + &back_algo_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, + filter_desc_, + out_desc_, + conv_desc_, + in_desc_, + back_algo_, + &back_size), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, + in_desc_, + out_desc_, + conv_desc_, + filter_desc_, + back_algo_w_, + &back_size_w), CUDNN_STATUS_SUCCESS); + back_size = std::max(back_size, back_size_w); + CHECK_EQ(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, + in_desc_, + filter_desc_, + conv_desc_, + out_desc_, + algo_, + &workspace), CUDNN_STATUS_SUCCESS); + workspace = std::max(workspace, back_size); + param_.workspace = workspace; + // TODO(bing): wait resource allocation + temp_.Resize(mshadow::Shape1(workspace / sizeof(real_t) + 1), 0.0f); + } + } + + bool init_cudnn_; + cudnnDataType_t dtype_; + cudnnTensorDescriptor_t in_desc_; + cudnnTensorDescriptor_t out_desc_; + cudnnTensorDescriptor_t bias_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnConvolutionFwdAlgo_t algo_; + cudnnConvolutionBwdDataAlgo_t back_algo_; + cudnnConvolutionBwdFilterAlgo_t back_algo_w_; + ConvolutionParam param_; + // TODO(bing): remove when we have resource manager + mshadow::TensorContainer temp_; +}; +#endif // __CUDACC__ && CUDNN +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_ diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index 75fe14d3aab8..35cb035d16f8 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -63,6 +63,10 @@ class FullyConnectedOp : public Operator { // maybe need blas handle from context // TODO(bing): judge shape to remove flatten op Stream *s = ctx.get_stream(); + #if defined(__CUDACC__) + CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + #endif // __CUDACC__ Tensor data = in_data[kData].FlatTo2D(s); Tensor wmat = in_data[kWeight].get(s); Tensor out = out_data[kOut].FlatTo2D(s); @@ -92,6 +96,10 @@ class FullyConnectedOp : public Operator { Tensor data = in_data[kData].FlatTo2D(s); Tensor wmat = in_data[kWeight].get(s); Tensor grad = out_grad[kOut].FlatTo2D(s); + #if defined(__CUDACC__) + CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + #endif // backprop CHECK_NE(req[kWeight], kWriteInplace) << "cannot write weight inplace"; // gradient of weight From d9f2e3d8232032258ba267cc38db81efbc7bc808 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 9 Sep 2015 20:47:24 -0600 Subject: [PATCH 03/13] cudnn activation, pooling --- Makefile | 2 +- example/cifar10/cifar10.py | 13 ++- include/mxnet/symbolic.h | 5 +- python/mxnet/context.py | 32 +++--- python/mxnet/executor.py | 19 ++-- python/mxnet/symbol.py | 14 ++- src/dag_engine/naive_engine.cc | 4 +- src/operator/activation-inl.h | 4 + src/operator/activation.cu | 7 ++ src/operator/convolution.cu | 3 +- src/operator/cudnn_activation-inl.h | 132 +++++++++++++++++++++++ src/operator/cudnn_convolution-inl.h | 5 +- src/operator/cudnn_pooling-inl.h | 154 +++++++++++++++++++++++++++ src/operator/pooling-inl.h | 4 + src/operator/pooling.cu | 7 ++ src/symbol/graph_executor.cc | 54 +++++++--- src/symbol/graph_executor.h | 4 +- 17 files changed, 408 insertions(+), 55 deletions(-) create mode 100644 src/operator/cudnn_activation-inl.h 
create mode 100644 src/operator/cudnn_pooling-inl.h diff --git a/Makefile b/Makefile index ea72c4dc3a1d..009e63e513b6 100644 --- a/Makefile +++ b/Makefile @@ -124,7 +124,7 @@ doxygen: doxygen doc/Doxyfile clean: - $(RM) -r build lib/* *~ */*~ */*/*~ */*/*/*~ + $(RM) -r build lib/*.a lib/*.so *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - -include build/*.d diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index a937060f6520..892f6fff5d1d 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -164,13 +164,14 @@ def RandomInit(narray): in_data = mx.narray.empty(data_shape, mx.gpu()) executor = loss.simple_bind(mx.gpu(), {"data": in_data}) out_narray = executor.heads()[0] -pred = mx.narray.zeros(out_narray.shape) +pred = mx.narray.zeros(out_narray.shape, mx.cpu()) arg_narrays, grad_narrays = executor.list_arguments() +inputs = dict(zip(loss.list_arguments(), arg_narrays)) +tmp_label = mx.narray.zeros(inputs["sm_label"].shape) momentum_narrays = [mx.narray.zeros(item.shape, mx.gpu()) for item in grad_narrays] -inputs = dict(zip(loss.list_arguments(), arg_narrays)) -block = zip(grad_narrays, arg_narrays, momentum_narrays) +block = list(zip(grad_narrays, arg_narrays, momentum_narrays)) np.random.seed(0) # set random weight @@ -210,7 +211,6 @@ def Update(grad, weight, mom): batch_size=batch_size, nthread=1) -tmp_label = mx.narray.zeros(inputs["sm_label"].shape) def progress(count, total, epoch, toc): bar_len = 50 @@ -236,8 +236,6 @@ def test_cifar(): val_nbatch = 0 all_train_bacth = round(50000 / float(batch_size) + 1) for data, label in train_dataiter: - if train_nbatch > 30: - break toc = time.time() label = label.asnumpy().flatten() tmp_label[:] = label @@ -247,7 +245,8 @@ def test_cifar(): pred[:] = out_narray train_acc += CalAcc(pred.asnumpy(), label) train_nbatch += 1 - executor.backward([out_narray]) + #executor.backward([out_narray]) + executor.backward() for grad, weight, mom in block: Update(grad, weight, mom) diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 97eed74a53be..ef9e562f64bc 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -394,9 +394,12 @@ class Executor { * \brief Perform a Backward operation of the Operator. * This must be called after Forward. * After this operation, NArrays specified by grad_in_args_store will be updated accordingly. + * User is allowed to pass in an empty Array if the head node is + * loss function and head gradeitn is not needed. + * * \param head_grads the gradient of head nodes to be backproped. */ - virtual void Backward(const std::vector &head_grads) = 0; + virtual void Backward(const std::vector &head_grads = {}) = 0; /*! * \brief get array of heads in the executor. * \return array of heads in the executor. diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 9d84f6915fbb..707591f9fc19 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -52,36 +52,44 @@ def __exit__(self, ptype, value, trace): # initialize the default context in Context Context.default_ctx = Context('cpu', 0) + def cpu(device_id=0): - """ - Return CPU context + """Return a CPU context. + + This function is a short cut for Context('cpu', device_id) Parameters ---------- - device_id : int (default=0) - the device id of the device, needed for GPU + device_id : int, optional + The device id of the device. device_id is not needed for CPU. + This is included to make interface compatible with GPU. 
Returns - --------- - A cpu context + ------- + context : Context + The corresponding CPU context. """ return Context('cpu', device_id) + def gpu(device_id=0): - """ - Return CPU context + """Return a GPU context. + + This function is a short cut for Context('cpu', device_id) Parameters ---------- - device_id : int (default=0) - the device id of the device, needed for GPU + device_id : int, optional + The device id of the device, needed for GPU Returns - --------- - A cpu context + ------- + context : Context + The corresponding GPU context. """ return Context('gpu', device_id) + def current_context(): """Return the current context. diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 312ae8edb6c6..a3ba09ca1a76 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -45,7 +45,7 @@ def list_arguments(self, with_grad=True): else: return self.arg_narrays - def list_auxiliary_states(): + def list_auxiliary_states(self): """Return auxiliary states of executor Note: auxiliary states is same to symbol.list_auxiliary_states() """ @@ -62,19 +62,24 @@ def forward(self, is_train=True): """ check_call(_LIB.MXExecutorForward(self.handle, is_train)) - def backward(self, grads): + def backward(self, head_grads=None): """Do backward on heads' gradient. Parameters ---------- - grads: Array of NArray - heads' gradient + head_grads : NArray or list of NArray, optional + Gradient on the heads """ - for obj in grads: + if head_grads is None: + head_grads = [] + elif isinstance(head_grads, NArray): + head_grads = [head_grads] + + for obj in head_grads: if not isinstance(obj, NArray): raise TypeError("inputs must be NArray") - narray = c_array(NArrayHandle, [item.handle for item in grads]) - check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray)) + narray = c_array(NArrayHandle, [item.handle for item in head_grads]) + check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), narray)) def heads(self): """list all heads' output narray diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index e64b0d8e8253..90df0b663615 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -352,7 +352,9 @@ def simple_bind(self, ctx, args, grad_req='write'): if not isinstance(args, dict): raise TypeError("args must be dict of str->NArray") input_shapes = dict((name, arr.shape) for name, arr in args.items()) + # pylint: disable=unused-variable arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) + # pylint: enable=unused-variable if arg_shapes == None: raise ValueError("Input node is not complete") # alloc space @@ -363,13 +365,10 @@ def simple_bind(self, ctx, args, grad_req='write'): else: arg_narrays.append(zeros(shape, ctx)) # TODO(bing): specail treat input data grad + # TODO(bing): not generate grad case grad_narrays = [zeros(shape, ctx) for shape in arg_shapes] aux_narrays = [zeros(shape, ctx) for shape in aux_shapes] executor = self.bind(ctx, arg_narrays, grad_narrays, grad_req, aux_narrays) - executor.arg_narrays = arg_narrays - executor.grad_narrays = grad_narrays - executor.auxiliary_states = aux_narrays - return executor def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): @@ -426,6 +425,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): User can give up gradient by using a dict in args_grad and only specify gradient they interested in. 
""" + # pylint: disable=too-many-locals if not isinstance(ctx, Context): raise TypeError("Context type error") @@ -470,7 +470,11 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): len(aux_states), aux_args_handle, ctypes.byref(handle))) - return Executor(handle) + executor = Executor(handle) + executor.arg_narrays = args + executor.grad_narrays = args_grad + executor.auxiliary_states = aux_states + return executor def grad(self, wrt): """Get the autodiff of current symbol. diff --git a/src/dag_engine/naive_engine.cc b/src/dag_engine/naive_engine.cc index 1cf1c07c5a62..912ea00349be 100644 --- a/src/dag_engine/naive_engine.cc +++ b/src/dag_engine/naive_engine.cc @@ -14,9 +14,9 @@ class NaiveEngine : public DAGEngine { stream_ = mshadow::NewStream(true, true); #else stream_ = mshadow::NewStream(true, false); - #endif // MXNET_USE_CUDNN + #endif // MXNET_USE_CUDNN ctx_.stream = stream_; - #endif // MXNET_USE_CUDA + #endif // MXNET_USE_CUDA } ~NaiveEngine() { diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 43aa4f01637a..2319f074cc73 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -117,7 +117,11 @@ class ActivationProp : public OperatorProperty { const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data) const override { + #if MXNET_USE_CUDNN == 1 + return {out_grad[kOut], out_data[kOut], in_data[kData]}; + #else return {out_grad[kOut], out_data[kOut]}; + #endif // MXNET_USE_CUDNN } std::vector > BackwardInplaceOption( diff --git a/src/operator/activation.cu b/src/operator/activation.cu index b1b8fc4fb8b0..4325d4c53a46 100644 --- a/src/operator/activation.cu +++ b/src/operator/activation.cu @@ -6,11 +6,17 @@ */ #include "./activation-inl.h" #include "./mshadow_op.h" +#if MXNET_USE_CUDNN == 1 +#include "./cudnn_activation-inl.h" +#endif namespace mxnet { namespace op { template<> Operator *CreateOp(ActivationParam param) { + #if MXNET_USE_CUDNN == 1 + return new CuDNNActivationOp(param); + #else switch(param.act_type) { case kReLU: return new ActivationOp(); case kSigmoid: return new ActivationOp(); @@ -19,6 +25,7 @@ Operator *CreateOp(ActivationParam param) { LOG(FATAL) << "unknown activation"; return NULL; } + #endif // MXNET_USE_CUDNN } } // op } // namespace mxnet diff --git a/src/operator/convolution.cu b/src/operator/convolution.cu index 8c7a5ebfe5be..8127cec43fd6 100644 --- a/src/operator/convolution.cu +++ b/src/operator/convolution.cu @@ -6,8 +6,9 @@ */ #include "./convolution-inl.h" +#if MXNET_USE_CUDNN == 1 #include "./cudnn_convolution-inl.h" - +#endif // MXNET_USE_CUDNN namespace mxnet { namespace op { diff --git a/src/operator/cudnn_activation-inl.h b/src/operator/cudnn_activation-inl.h new file mode 100644 index 000000000000..1158a1324128 --- /dev/null +++ b/src/operator/cudnn_activation-inl.h @@ -0,0 +1,132 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file cudnn_activation-inl.h + * \brief + * \author Bing Xu +*/ + +#ifndef MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_ +#define MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_ +#include +#include +#include "./activation-inl.h" + +namespace mxnet { +namespace op { +class CuDNNActivationOp : public Operator { + public: + explicit CuDNNActivationOp(ActivationParam param) { + param_ = param; + init_cudnn_ = false; + dtype_ = CUDNN_DATA_FLOAT; + switch (param_.act_type) { + case kReLU: + mode_ = CUDNN_ACTIVATION_RELU; + break; + case kSigmoid: + mode_ = CUDNN_ACTIVATION_SIGMOID; + break; + case kTanh: + mode_ = CUDNN_ACTIVATION_TANH; + break; + default: + LOG(FATAL) << "Not implmented"; + break; + } + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + float alpha = 1.0f; + float beta = 0.0f; + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + if (!init_cudnn_) { + this->Init(s, in_data, out_data); + } + CHECK_EQ(cudnnActivationForward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req.size(), 1); + CHECK_EQ(in_grad.size(), 1); + float alpha = 1.0f; + float beta = 0.0f; + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[kOut].get(s); + Tensor data = in_data[kData].get(s); + Tensor output_data = out_data[kOut].get(s); + Tensor input_grad = in_grad[kData].get(s); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_), CUDNN_STATUS_SUCCESS); + } + + private: + inline void Init(mshadow::Stream *s, + const std::vector &in_data, + const std::vector &out_data) { + using namespace mshadow; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + if (!init_cudnn_) { + init_cudnn_ = true; + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + CHECK_EQ(data.shape_, out.shape_); + CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3]), CUDNN_STATUS_SUCCESS); + } + } + bool init_cudnn_; + cudnnDataType_t dtype_; + cudnnActivationMode_t mode_; + cudnnTensorDescriptor_t shape_desc_; + ActivationParam param_; +}; // class CuDNNActivationOp +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CUDNN_ACTIVATION_INL_H_ diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 837cf4c1fbcb..8b81818304e1 100644 --- a/src/operator/cudnn_convolution-inl.h +++ 
b/src/operator/cudnn_convolution-inl.h @@ -6,6 +6,9 @@ */ #ifndef MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_ #define MXNET_OPERATOR_CUDNN_CONVOLUTION_INL_H_ + +#include +#include #include "./convolution-inl.h" namespace mxnet { @@ -265,7 +268,7 @@ class CuDNNConvolutionOp : public Operator { // TODO(bing): remove when we have resource manager mshadow::TensorContainer temp_; }; -#endif // __CUDACC__ && CUDNN +#endif // __CUDACC__ && CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/cudnn_pooling-inl.h b/src/operator/cudnn_pooling-inl.h new file mode 100644 index 000000000000..83faeee70435 --- /dev/null +++ b/src/operator/cudnn_pooling-inl.h @@ -0,0 +1,154 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cudnn_pooling-inl.h + * \brief + * \author Bing Xu +*/ + +#ifndef MXNET_OPERATOR_CUDNN_POOLING_INL_H_ +#define MXNET_OPERATOR_CUDNN_POOLING_INL_H_ +#include +#include +#include "./pooling-inl.h" + +namespace mxnet { +namespace op { + +class CuDNNPoolingOp : public Operator { + public: + explicit CuDNNPoolingOp(PoolingParam p) { + param_ = p; + init_cudnn_ = false; + // TODO(xxx): fp16 + dtype_ = CUDNN_DATA_FLOAT; + switch (param_.pool_type) { + case kMaxPooling: + mode_ = CUDNN_POOLING_MAX; + break; + case kAvgPooling: + mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + break; + default: + LOG(FATAL) << "Not implmented"; + } + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + if (!init_cudnn_) { + this->Init(s, in_data, out_data); + } + float alpha = 1.0f; + float beta = 0.0f; + CHECK_EQ(data.CheckContiguous(), true); + CHECK_EQ(out.CheckContiguous(), true); + CHECK_EQ(cudnnPoolingForward(s->dnn_handle_, + pooling_desc_, + &alpha, + in_desc_, + data.dptr_, + &beta, + out_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req.size(), 1); + CHECK_EQ(in_grad.size(), 1); + + Stream *s = ctx.get_stream(); + Tensor m_out_grad = out_grad[kOut].get(s); + Tensor m_in_data = in_data[kData].get(s); + Tensor m_out_data = out_data[kOut].get(s); + Tensor m_in_grad = in_grad[kData].get(s); + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + float alpha = 1.0f; + float beta = 0.0f; + CHECK_EQ(cudnnPoolingBackward(s->dnn_handle_, + pooling_desc_, + &alpha, + out_desc_, + m_out_data.dptr_, + out_desc_, + m_out_grad.dptr_, + in_desc_, + m_in_data.dptr_, + &beta, + in_desc_, + m_in_grad.dptr_), CUDNN_STATUS_SUCCESS); + } + + private: + inline void Init(mshadow::Stream *s, + const std::vector &in_data, + const std::vector &out_data) { + using namespace mshadow; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + if (!init_cudnn_) { + init_cudnn_ = true; + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + 
CHECK_EQ(cudnnCreatePoolingDescriptor(&pooling_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3]), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + out.shape_[0], + out.shape_[1], + out.shape_[2], + out.shape_[3]), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_, + mode_, + param_.kernel[0], + param_.kernel[1], + param_.pad[0], + param_.pad[1], + param_.stride[0], + param_.stride[1]), CUDNN_STATUS_SUCCESS); + } + } + bool init_cudnn_; + cudnnDataType_t dtype_; + cudnnHandle_t handle_; + cudnnPoolingMode_t mode_; + cudnnTensorDescriptor_t in_desc_; + cudnnTensorDescriptor_t out_desc_; + cudnnPoolingDescriptor_t pooling_desc_; + PoolingParam param_; +}; // class CuDNNPoolingOp +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CUDNN_POOLING_INL_H_ + diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h index b0d483ef0217..5748325d5835 100644 --- a/src/operator/pooling-inl.h +++ b/src/operator/pooling-inl.h @@ -201,7 +201,11 @@ class PoolingProp : public OperatorProperty { const std::vector &in_data, const std::vector &out_data, const std::vector &in_grad) const override { + #if MXNET_USE_CUDNN == 1 + return {}; + #else return {{in_data[kData], in_grad[kData]}}; + #endif } Operator* CreateOperator(Context ctx) const; diff --git a/src/operator/pooling.cu b/src/operator/pooling.cu index 5037050ccd6f..df9547bf4a1e 100644 --- a/src/operator/pooling.cu +++ b/src/operator/pooling.cu @@ -6,11 +6,17 @@ */ #include "./pooling-inl.h" +#if MXNET_USE_CUDNN == 1 +#include "./cudnn_pooling-inl.h" +#endif // MXNET_USE_CUDNN namespace mxnet { namespace op { template<> Operator *CreateOp(PoolingParam param) { + #if MXNET_USE_CUDNN == 1 + return new CuDNNPoolingOp(param); + #else switch (param.pool_type) { case kMaxPooling: return new PoolingOp(param); case kAvgPooling: return new PoolingOp(param); @@ -19,6 +25,7 @@ Operator *CreateOp(PoolingParam param) { LOG(FATAL) << "unknown activation type"; return NULL; } + #endif // MXNET_USE_CUDNN } } // namespace op diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index aeff3427d8f3..035d7f46b284 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -323,6 +323,13 @@ void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, } void GraphExecutor::InitDataEntryMemory() { + // setup the temp ref counter for allocator algorithms + for (OpNode &op : op_nodes_) { + for (DataEntryInfo &node : op.outputs) { + node.temp_ref_count = node.ref_count; + } + } + // use allocator to allocate memory. 
GraphStorageAllocator allocator(&graph_); for (size_t i = 0; i < topo_order_.size(); ++i) { @@ -337,7 +344,7 @@ void GraphExecutor::InitDataEntryMemory() { for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; CHECK_NE(info.type, kNotInitialized); - CHECK_NE(info.ref_count, 0); + CHECK_NE(info.temp_ref_count, 0); in_data.push_back(&info); } std::vector out_data(op_nodes_[nid].outputs.size()); @@ -350,7 +357,7 @@ void GraphExecutor::InitDataEntryMemory() { for (std::pair kv : inplace) { DataEntryInfo* in = kv.first; DataEntryInfo* out = kv.second; - if (in->ref_count == 1 && + if (in->temp_ref_count == 1 && in->type == kInternalAllocated && out->type == kNotInitialized) { // we can only do inplace if we are last user of in @@ -359,13 +366,13 @@ void GraphExecutor::InitDataEntryMemory() { out->op_req = kWriteInplace; out->storage_id = in->storage_id; // set inplace op id - in->ref_count = 0; + in->temp_ref_count = 0; in->inplace_op_id = static_cast(nid); } } // allocate output, for (DataEntryInfo *out : out_data) { - if (out->op_req == kNullOp && out->ref_count != 0) { + if (out->op_req == kNullOp && out->temp_ref_count != 0) { out->op_req = kWriteTo; } if (out->type == kNotInitialized) { @@ -376,20 +383,20 @@ void GraphExecutor::InitDataEntryMemory() { } // then free inputs for (DataEntryInfo *in : in_data) { - // ref_count == 0 means it is taken by inplace op - if (in->ref_count == 0) { + // temp_ref_count == 0 means it is taken by inplace op + if (in->temp_ref_count == 0) { CHECK_EQ(in->inplace_op_id, static_cast(nid)); continue; } // if we decrease it to zero, means we are ready to relase - --in->ref_count; - if (in->ref_count == 0 && in->type == kInternalAllocated) { + --in->temp_ref_count; + if (in->temp_ref_count == 0 && in->type == kInternalAllocated) { allocator.Release(in->storage_id, nid); } } - // check out again, if there is ref_count == 0, release it + // check out again, if there is temp_ref_count == 0, release it for (DataEntryInfo *out : out_data) { - if (out->ref_count == 0 && out->type == kInternalAllocated) { + if (out->temp_ref_count == 0 && out->type == kInternalAllocated) { allocator.Release(out->storage_id, nid); } } @@ -493,13 +500,26 @@ void GraphExecutor::Forward(bool is_train) { } void GraphExecutor::Backward(const std::vector &head_grads) { - CHECK_EQ(head_grad_nodes_.size(), head_grads.size()); - for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { - uint32_t nid = head_grad_nodes_[i]; - CHECK(graph_.nodes[nid].is_variable()); - DataEntryInfo &info = op_nodes_[nid].outputs[0]; - CHECK_EQ(info.type, kTobeBindByExternal); - info.data = head_grads[i]; + if (head_grads.size() != 0) { + // TODO(bing, min): consider pass a map for backward + CHECK_EQ(head_grad_nodes_.size(), head_grads.size()); + for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { + uint32_t nid = head_grad_nodes_[i]; + CHECK(graph_.nodes[nid].is_variable()); + DataEntryInfo &info = op_nodes_[nid].outputs[0]; + CHECK_EQ(info.type, kTobeBindByExternal); + info.data = head_grads[i]; + } + } else { + // check all the head_grad_nodes need to have zero ref_count + // loss function do not need out_grad + for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { + uint32_t nid = head_grad_nodes_[i]; + DataEntryInfo &info = op_nodes_[nid].outputs[0]; + CHECK_EQ(info.ref_count, 0) + << "Because the last operator is not Loss function, " + << "head_gradient is required in calling backward."; + } } RunOps(true, num_forward_nodes_, 
topo_order_.size()); } diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 66cd074b406b..b495ad3df179 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -79,6 +79,8 @@ class GraphExecutor : public Executor { // reference count on how many times this entry is being used. // That is how many operators and heads need this DataEntry // this is a temporal variable that is used during initialization. + uint32_t temp_ref_count; + // real permanent ref count uint32_t ref_count; // constructor DataEntryInfo() @@ -86,7 +88,7 @@ class GraphExecutor : public Executor { inplace_op_id(-1), type(kNotInitialized), storage_id(GraphStorageAllocator::kBadStorageID), - ref_count(0) {} + temp_ref_count(0), ref_count(0) {} }; // all the information needed to push the op to engine struct OpExecEntry { From a28a83a49426d167d2bc2a0350aeac5fd3656c47 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 12 Sep 2015 10:41:38 -0700 Subject: [PATCH 04/13] add stream wait to all pushes, change kvstore to redef pinned memory --- include/mxnet/base.h | 4 +- include/mxnet/context.h | 9 ++ include/mxnet/operator.h | 2 +- src/kvstore/kvstore_local.h | 14 +-- src/narray/narray.cc | 187 +++++++++++++--------------- src/symbol/graph_executor.cc | 19 ++- src/symbol/graph_executor.h | 2 +- src/symbol/graph_memory_allocator.h | 2 +- 8 files changed, 127 insertions(+), 112 deletions(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index e3fbe002fdfc..7e5b6d06f0d8 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -33,6 +33,9 @@ #define MXNET_USE_CUDNN MSHADOW_USE_CUDNN #endif +/*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */ +#define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" + /*! \brief namespace of mxnet */ namespace mxnet { /*! \brief mxnet cpu */ @@ -50,7 +53,6 @@ typedef mshadow::TShape TShape; typedef mshadow::TBlob TBlob; } // namespace mxnet - //! \cond Doxygen_Suppress namespace dmlc { // Add a few patches to support TShape in dmlc/parameter. diff --git a/include/mxnet/context.h b/include/mxnet/context.h index e7e2e2b0e44b..c0a712bc8ec8 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -83,6 +83,15 @@ struct RunContext { * \brief the stream of the device, can be NULL or Stream* in GPU mode */ void *stream; + /*! + * \brief get mshadow stream from Context + * \return the mshadow stream + * \tparam xpu the device type of the stream + */ + template + inline mshadow::Stream* get_stream() const { + return static_cast*>(stream); + } }; /*! 
diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 64c1515a6f3b..0842f53d347e 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -53,7 +53,7 @@ struct OpContext { */ template inline mshadow::Stream* get_stream() const { - return static_cast*>(run_ctx.stream); + return run_ctx.get_stream(); } }; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 34c206bfc6de..bea6cb019356 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -19,7 +19,12 @@ namespace mxnet { */ class KVStoreLocal : public KVStore { public: - KVStoreLocal() : pinned_ctx_(cpu::kDevMask, Context::kPinnedMemoryID) { + KVStoreLocal() { +#if MXNET_USE_CUDA + pinned_ctx_ = Context(cpu::kDevMask, Context::kPinnedMemoryID); +#else + pinned_ctx_ = Context(cpu::kDevMask, 0); +#endif Clear(); } @@ -44,11 +49,7 @@ class KVStoreLocal : public KVStore { for (size_t i = 0; i < keys.size(); ++i) { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; -#if MXNET_USE_CUDA local_.insert({keys[i], values[i].Copy(pinned_ctx_)}); -#else - local_.insert({keys[i], values[i].Copy(local_ctx_)}); -#endif // MXNET_USE_CUDA } } @@ -121,7 +122,7 @@ class KVStoreLocal : public KVStore { CHECK(val.size()); auto& buf = merge_buf_[key]; if (buf.merged.is_none()) { - buf.merged = val[0].Copy(local_ctx_); + buf.merged = val[0].Copy(pinned_ctx_); } else { CopyFromTo(val[0], &buf.merged); } @@ -167,7 +168,6 @@ class KVStoreLocal : public KVStore { /// \brief local storage std::unordered_map local_; - Context local_ctx_; Context pinned_ctx_; Updater updater_; diff --git a/src/narray/narray.cc b/src/narray/narray.cc index 7bce1d5c5243..22097eb46e86 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include #include "./narray_function.h" @@ -42,45 +43,34 @@ inline void BinaryOp(const NArray &lhs, } // important: callback must always capture by value NArray ret = *out; + // get the const variables + std::vector const_vars; + if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); + if (rhs.ptr_->var != ret.ptr_->var) const_vars.push_back(rhs.ptr_->var); + // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { - auto func = [lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }; - if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); - } else if (lhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); - } else if (rhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); - } else { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); - } + Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { - auto func = [lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }; - if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); - } else if (lhs.ptr_->var == ret.ptr_->var) { 
- Engine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); - } else if (rhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); - } else { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); - } + Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } #endif - default: LOG(FATAL) << "GPU is not enabled"; + default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } @@ -90,26 +80,26 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { NArray ret = *out; switch (ret.ctx().dev_mask) { case cpu::kDevMask: { - auto func = [rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(rhs, &tmp, ctx); - }; - Engine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push([rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(rhs, &tmp, ctx); + }, ret.ctx(), {}, {ret.ptr_->var}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { - auto func = [rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(rhs, &tmp, ctx); - }; - Engine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + Engine::Get()->Push([rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(rhs, &tmp, ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, ret.ctx(), {}, {ret.ptr_->var}); break; } #endif - default: LOG(FATAL) << "GPU is not enabled"; + default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } /*! 
@@ -124,45 +114,40 @@ inline void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out) { if (out->is_none()) { - *out = NArray(OP::GetShape(lhs.shape(), lhs.shape()), lhs.ctx(), true); + *out = NArray(lhs.shape(), lhs.ctx(), true); } else { CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; - CHECK(out->shape() == OP::GetShape(lhs.shape(), lhs.shape())) - << "target shape mismatch"; + CHECK(out->shape() == lhs.shape()) << "target shape mismatch"; } // important: callback must always capture by value NArray ret = *out; + // get the const variables + std::vector const_vars; + if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); + // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { - auto func = [lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); - }; - if (lhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); - } else { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); - } + Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs, &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } #if MXNET_USE_CUDA case gpu::kDevMask: { - auto func = [lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); - }; - if (lhs.ptr_->var == ret.ptr_->var) { - Engine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); - } else { - Engine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); - } + Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs, &tmp, ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } #endif - default: LOG(FATAL) << "GPU is not enabled"; + default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } } @@ -175,48 +160,52 @@ void CopyFromTo(const NArray &from, NArray *to) { NArray ret = *to; int a = from.ctx().dev_mask; int b = to->ctx().dev_mask; + + std::vector const_vars; + if (from.ptr_->var != ret.ptr_->var) const_vars.push_back(from.ptr_->var); + if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); - }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); - } else if (a == cpu::kDevMask && b == gpu::kDevMask) { -#if MXNET_USE_CUDA - Engine::Get()->Push([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - }, ret.ctx(), {from.ptr_->var}, {ret.ptr_->var}); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } else if (a == gpu::kDevMask && b == cpu::kDevMask) { -#if MXNET_USE_CUDA - Engine::Get()->Push([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } else if (a == gpu::kDevMask && b == gpu::kDevMask) { + }, from.ctx(), const_vars, {ret.ptr_->var}); + } else { #if MXNET_USE_CUDA - Engine::Get()->Push([from, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); 
- narray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); + if (a == cpu::kDevMask && b == gpu::kDevMask) { + Engine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, + from.ctx(), ret.ctx(), ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, ret.ctx(), const_vars, {ret.ptr_->var}); + } else if (a == gpu::kDevMask && b == cpu::kDevMask) { + Engine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, + from.ctx(), ret.ctx(), ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, from.ctx(), const_vars, {ret.ptr_->var}); + } else if (a == gpu::kDevMask && b == gpu::kDevMask) { + Engine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, + from.ctx(), ret.ctx(), ctx); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, from.ctx(), const_vars, {ret.ptr_->var}); + } else { + LOG(FATAL) << "unknown device mask"; + } #else - LOG(FATAL) << "GPU is not enabled"; + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif - } else { - LOG(FATAL) << "unknown device mask"; } } diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 863355c28937..1211f1a4abb4 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -169,7 +169,6 @@ inline std::vector > GraphExecutor::GetInplaceOption( inline GraphExecutor::OpExecEntry GraphExecutor::GetOpExecEntry(uint32_t nid) { OpNode& op_node = op_nodes_[nid]; - Operator *op = op_node.op.get(); std::vector req; std::vector in_data, out_data, aux_states; in_data.reserve(graph_.nodes[nid].inputs.size()); @@ -199,14 +198,30 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { } } + // start setup exec function. + Operator* op = op_node.op.get(); OpContext* op_ctx_ptr = &op_node.op_ctx; - exec.exec_fun = [op, op_ctx_ptr, in_data, req, out_data, aux_states] (RunContext ctx) { + bool is_gpu = op_node.ctx.dev_mask == gpu::kDevMask; + exec.exec_fun = [op, is_gpu, op_ctx_ptr, in_data, req, out_data, aux_states] (RunContext ctx) { op_ctx_ptr->run_ctx = ctx; op->Forward(*op_ctx_ptr, in_data, req, out_data, aux_states); + if (is_gpu) { +#if MXNET_USE_CUDA + // Wait GPU kernel to finish. + ctx.get_stream()->Wait(); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } }; return exec; } +GraphExecutor::~GraphExecutor() { + // need to destruct after all previously issued operations are finished. 
+ Engine::Get()->WaitForAll(); +} + void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { // initialize all internal data structures symbol.ToStaticGraph(&graph_); diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 074fafa0c571..5db11fcdf779 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -19,7 +19,7 @@ namespace mxnet { */ class GraphExecutor : public Executor { public: - virtual ~GraphExecutor() {} + virtual ~GraphExecutor(); virtual void Forward(bool is_train); virtual void Backward(const std::vector &head_grads); virtual const std::vector &heads() const { diff --git a/src/symbol/graph_memory_allocator.h b/src/symbol/graph_memory_allocator.h index b7bd2db2081e..9c995cd29993 100644 --- a/src/symbol/graph_memory_allocator.h +++ b/src/symbol/graph_memory_allocator.h @@ -56,7 +56,7 @@ class GraphStorageAllocator { */ NArray Get(StorageID id, TShape shape); - private: + protected: /*! \brief internal storage entry */ struct StorageEntry { /*! \brief id of the storage */ From 04517ccbbf54fb8d3656528611848d85d3903dbd Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 12 Sep 2015 16:49:36 -0700 Subject: [PATCH 05/13] Change engine callback to old style; support multiple streams in Naive. Make PushSync inline; Use registry and env variable to select engine move tests into fine grained folders checkin gtest fix test --- .gitignore | 1 + .travis.yml | 3 +- Makefile | 9 +- include/mxnet/base.h | 11 ++ include/mxnet/engine.h | 161 ++++++++++-------- include/mxnet/kvstore.h | 2 +- include/mxnet/narray.h | 2 +- include/mxnet/symbolic.h | 2 +- scripts/travis_script.sh | 23 ++- src/c_api.cc | 2 +- src/common/cuda_utils.h | 9 +- src/common/object_pool.h | 3 +- src/common/utils.h | 3 +- src/engine/engine.cc | 47 ++--- src/engine/engine_impl.h | 44 +++-- src/engine/naive_engine.cc | 146 +++++++++------- src/engine/naive_engine.h | 47 ----- src/engine/thread_pool.h | 1 - src/engine/threaded_engine.cc | 102 +++++------ src/engine/threaded_engine.h | 41 ++--- src/narray/narray.cc | 20 +-- src/symbol/graph_executor.cc | 14 +- src/symbol/graph_executor.h | 2 +- tests/.gitignore | 1 + tests/cpp/.gitignore | 1 + .../storage_unittest.cc} | 33 ++-- .../threaded_engine_unittest.cc} | 21 +-- tests/cpp/unittest.mk | 16 ++ tests/python/README.md | 10 ++ tests/python/{ => common}/get_data.py | 0 tests/python/{ => common}/models.py | 0 tests/python/train/common.py | 6 + tests/python/{ => train}/test_conv.py | 3 +- tests/python/{ => train}/test_mlp.py | 5 +- tests/python/unittest/common.py | 6 + tests/python/{ => unittest}/test_bind.py | 0 .../python/{ => unittest}/test_infer_shape.py | 2 +- tests/python/{ => unittest}/test_io.py | 5 +- tests/python/{ => unittest}/test_kvstore.py | 0 tests/python/{ => unittest}/test_narray.py | 0 tests/python/{ => unittest}/test_operator.py | 0 tests/python/{ => unittest}/test_symbol.py | 2 +- 42 files changed, 438 insertions(+), 368 deletions(-) delete mode 100644 src/engine/naive_engine.h create mode 100644 tests/cpp/.gitignore rename tests/{test_storage.cc => cpp/storage_unittest.cc} (54%) rename tests/{test_threaded_engine.cc => cpp/threaded_engine_unittest.cc} (87%) create mode 100644 tests/cpp/unittest.mk create mode 100644 tests/python/README.md rename tests/python/{ => common}/get_data.py (100%) rename tests/python/{ => common}/models.py (100%) create mode 100644 tests/python/train/common.py rename tests/python/{ => train}/test_conv.py (99%) rename tests/python/{ => train}/test_mlp.py (98%) create mode 
100644 tests/python/unittest/common.py rename tests/python/{ => unittest}/test_bind.py (100%) rename tests/python/{ => unittest}/test_infer_shape.py (97%) rename tests/python/{ => unittest}/test_io.py (98%) rename tests/python/{ => unittest}/test_kvstore.py (100%) rename tests/python/{ => unittest}/test_narray.py (100%) rename tests/python/{ => unittest}/test_operator.py (100%) rename tests/python/{ => unittest}/test_symbol.py (97%) diff --git a/.gitignore b/.gitignore index 480e945bfc0b..56c1e88370c9 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,4 @@ __pycache__ build dmlc-core mshadow +data diff --git a/.travis.yml b/.travis.yml index a02fbe658554..5c7a5d2562a6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,7 +11,7 @@ env: - TASK=python CXX=g++ - TASK=python3 CXX=g++ - TASK=python_naive CXX=g++ - - TASK=unittest_gtest CXX=g++ + - TASK=cpp_unittest CXX=g++ # dependent apt packages addons: @@ -47,6 +47,7 @@ before_install: install: - pip install cpplint pylint --user `whoami` + - make -f dmlc-core/scripts/packages.mk gtest - if [ "$CXX" = "g++" ]; then export CXX="g++-4.8" CC="gcc-4.8"; fi script: diff --git a/Makefile b/Makefile index e5fc24acc97a..879f534b0a9c 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,6 @@ endif .PHONY: clean all test lint doc -BIN = tests/test_threaded_engine all: lib/libmxnet.a lib/libmxnet.so $(BIN) SRC = $(wildcard src/*.cc src/*/*.cc) @@ -114,13 +113,13 @@ lib/libmxnet.a: $(ALL_DEP) lib/libmxnet.so: $(ALL_DEP) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) -tests/% : tests/%.cc lib/libmxnet.a - $(CXX) -std=c++0x $(CFLAGS) -MM -MT tests/$*.o $< >tests/$*.d - $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) - $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) +include tests/cpp/unittest.mk + +test: tests/cpp/unittest + lint: python dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts python diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 7e5b6d06f0d8..7f2fc7c07a0b 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -36,6 +36,17 @@ /*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */ #define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" +/*! + * \brief define compatible keywords in g++ + * Used to support g++-4.6 and g++4.7 + */ +#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) +#if __GNUC__ == 4 && __GNUC_MINOR__ == 6 +#define override +#define final +#endif +#endif + /*! \brief namespace of mxnet */ namespace mxnet { /*! \brief mxnet cpu */ diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 91b9f2a72b8d..f185da8215c3 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -1,72 +1,79 @@ /*! * Copyright (c) 2015 by Contributors * \file engine.h - * \brief Engine that schedules data. + * \brief Engine that schedules all the operations according to dependency. */ #ifndef MXNET_ENGINE_H_ #define MXNET_ENGINE_H_ -#include - -#if DMLC_USE_CXX11 == 0 -#error "C++11 was required for engine module." -#endif +#include +#if DMLC_USE_CXX11 #include +#endif #include -#include "base.h" -#include "context.h" +#include "./base.h" +#include "./context.h" namespace mxnet { - -/*! - * \brief Namespace of engine implementation. - */ +/*! \brief namespace of engine internal types. */ namespace engine { - -/*! - * \brief Inner representation of variable. - */ +/*! \brief Internal representation of variable. */ struct Var; - -/*! - * \brief Inner representation of operator. - */ +/*! 
\brief Internal representation of operator. */ struct Opr; - +/*! \brief Variable pointer type, usually hold by user used to specify dependencies. */ +typedef Var* VarHandle; +/*! \brief Operator pointer type, usually hold by user.*/ +typedef Opr* OprHandle; } // namespace engine -/*! - * \brief Function property. - */ -enum class FnProperty { kNormal, kIO, kAsync }; // enum class FnProperty +#if DMLC_USE_CXX11 + +/*! \brief Function property, used to hint what action is pushed to engine. */ +enum class FnProperty { + /*! \brief Normal operation */ + kNormal, + /*! \brief Copy operation between CPU and GPU */ + kCopy, + /*! \brief Asynchronous function call */ + kAsync +}; // enum class FnProperty /*! - * \brief Dynamic dataflow engine that schedules operations. - */ + * \brief Dependency engine that schedules operations. +*/ class Engine { public: /*! - * \brief Operation to pass to engine. - */ - using Fn = std::function; - /*! - * \brief Callback function to notify operation complete. - */ - using Callback = std::function; - /*! - * \brief Asynchronous operation to pass to engine. - */ - using AsyncFn = std::function; - /*! - * \brief Variable of engine, used to specify dependencies defined to be a - * pointer, that points to an internal data structure of the engine - * itself. - */ - using VarHandle = engine::Var*; - /*! - * \brief Operator of the engine. + * \brief OnComplete Callback to the engine, + * called by AsyncFn when action completes */ - using OprHandle = engine::Opr*; + class CallbackOnComplete { + public: + // use implicit copy and assign + /*! \brief involve the callback */ + inline void operator()() const { + (*callback_)(engine_, param_); + } + + private: + /*! \brief engine can see content of callback */ + friend class ::mxnet::Engine; + /*! \brief the real callback */ + void (*callback_)(Engine *, void *); + /*! \brief the engine class passed to callback */ + Engine* engine_; + /*! \brief the parameter set on callback */ + void* param_; + }; + /*! \brief Synchronous operation to pass to engine. */ + typedef std::function SyncFn; + /*! \brief Asynchronous operation to pass to engine. */ + typedef std::function AsyncFn; + /*! \brief Variable pointer */ + typedef engine::VarHandle VarHandle; + /*! \brief Operator pointer */ + typedef engine::OprHandle OprHandle; /*! * \brief Allocate a new variable, the variable can then * be used to schedule the operation concurrently via dependency @@ -102,19 +109,6 @@ class Engine { * \param exec_ctx Execution context. */ virtual void Push(OprHandle op, Context exec_ctx) = 0; - /*! - * \brief Push an synchronous operation to the engine. - * \param exec_fun Execution function that executes the operation. - * \param exec_ctx Execution context. - * \param const_vars The variables that current operation will use but not - * mutate. - * \param mutable_vars The variables that current operation will mutate. - * \param prop Property of the function. - */ - virtual void Push(Fn exec_fun, Context exec_ctx, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal) = 0; /*! * \brief Push an asynchronous operation to the engine. * \param exec_fun Execution function, this function takes a parameter @@ -141,7 +135,8 @@ class Engine { * \param exec_ctx Execution context. * \param var The variable to be deleted. */ - virtual void DeleteVariable(Fn delete_fun, Context exec_ctx, + virtual void DeleteVariable(SyncFn delete_fn, + Context exec_ctx, VarHandle var) = 0; /*! * \brief Wait for a variable. 
@@ -153,16 +148,48 @@ class Engine { * \brief Wait until all the activity of engine finishes. */ virtual void WaitForAll() = 0; - /*! - * \brief Virtual destructor. - */ - virtual ~Engine() noexcept(false); + /*!\brief virtual destructor */ + virtual ~Engine() noexcept(false) {} /*! * \return Engine singleton. */ static Engine* Get(); -}; // class Engine + /*! + * \brief Push an synchronous operation to the engine. + * \param exec_fn Execution function that executes the operation. + * \param exec_ctx Execution context. + * \param const_vars The variables that current operation will use but not + * mutate. + * \param mutable_vars The variables that current operation will mutate. + * \param prop Property of the function. + * \tparam SyncFn the synchronous function to be pushed. + */ + template + inline void PushSync(SyncFn exec_fn, Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop = FnProperty::kNormal) { + this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { + exec_fn(ctx); + on_complete(); + }, exec_ctx, const_vars, mutable_vars, prop); + } + protected: + /*! + * \brief factory function to create OnComplete callback. + * \param callback th static callback function. + * \param param the paramter passed to callback. + */ + inline CallbackOnComplete CreateCallback( + void (*callback)(Engine *, void *), void *param) { + CallbackOnComplete ret; + ret.callback_ = callback; + ret.engine_ = this; + ret.param_ = param; + return ret; + } +}; // class Engine +#endif // DMLC_USE_CXX11 } // namespace mxnet - #endif // MXNET_ENGINE_H_ diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 43c14a410a07..ef4be5102578 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -118,7 +118,7 @@ class KVStore { /** * \brief the prototype of user-defined updater */ - using Updater = std::function; + typedef std::function Updater; /*! \brief returns the default updater, which is ASSIGN */ Updater DefaultUpdater() { diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index fb8fc2b7484d..20372524b3ac 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -91,7 +91,7 @@ class NArray { * Push an empty mutable function to flush all preceding reads to the * variable. */ - Engine::Get()->Push([](RunContext) {}, Context{}, {}, {ptr_->var}); + Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var}); Engine::Get()->WaitForVar(ptr_->var); } /*! \return the associated variable of the narray.*/ diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index ef9e562f64bc..28e82da32c06 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -399,7 +399,7 @@ class Executor { * * \param head_grads the gradient of head nodes to be backproped. */ - virtual void Backward(const std::vector &head_grads = {}) = 0; + virtual void Backward(const std::vector &head_grads) = 0; /*! * \brief get array of heads in the executor. * \return array of heads in the executor. 
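As a usage illustration of the reworked interface above (a hedged sketch, not code from this patch): PushSync wraps a plain function and fires the completion callback itself, while PushAsync hands a CallbackOnComplete to the function so completion can be signalled later, possibly from another thread.

    #include <mxnet/engine.h>

    // Illustrative sketch of the revised Engine API; names are made up.
    void EngineApiSketch() {
      using namespace mxnet;
      Engine* engine = Engine::Get();
      Engine::VarHandle var = engine->NewVariable();

      // Synchronous form: the engine invokes the completion callback for us.
      engine->PushSync([](RunContext ctx) {
        // CPU work guarded by `var` goes here.
      }, Context{}, {}, {var}, FnProperty::kNormal);

      // Asynchronous form: the pushed function must call on_complete() itself
      // once the real work has finished.
      engine->PushAsync([](RunContext ctx, Engine::CallbackOnComplete on_complete) {
        on_complete();
      }, Context{}, {}, {var}, FnProperty::kAsync);

      engine->WaitForVar(var);
      engine->DeleteVariable([](RunContext) {}, Context{}, var);
      engine->WaitForAll();
    }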
diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 4b52c354df19..99d1771d1ac7 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -29,22 +29,35 @@ fi if [ ${TASK} == "python" ]; then echo "USE_CUDA=0" >> config.mk - echo "USE_THREADED_ENGINE=1" >> config.mk make all || exit -1 - nosetests tests/python || exit -1 + export MXNET_ENGINE_TYPE=ThreadedEngine + nosetests tests/python/unittest || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "python3" ]; then echo "USE_CUDA=0" >> config.mk - echo "USE_THREADED_ENGINE=1" >> config.mk make all || exit -1 - nosetests3 tests/python || exit -1 + export MXNET_ENGINE_TYPE=ThreadedEngine + nosetests tests/python/unittest || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "python_naive" ]; then echo "USE_CUDA=0" >> config.mk make all || exit -1 - nosetests tests/python || exit -1 + export MXNET_ENGINE_TYPE=NaiveEngine + nosetests tests/python/unittest || exit -1 + nosetests tests/python/train || exit -1 +fi + +if [ ${TASK} == "cpp_unittest" ]; then + echo "USE_CUDA=0" >> config.mk + make test || exit -1 + export MXNET_ENGINE_TYPE=NaiveEngine + testsp/cpp/unittest || exit -1 + export MXNET_ENGINE_TYPE=ThreadedEngine + tests/cpp/unittest || exit -1 fi # TODO(yutian): add unittest back diff --git a/src/c_api.cc b/src/c_api.cc index 586ea36b166a..7ce3dac7cf61 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -265,7 +265,7 @@ int MXNArrayWaitToWrite(NArrayHandle handle) { API_END(); } -const int kMXAPINArrayListMagic = 0x112; +const uint64_t kMXAPINArrayListMagic = 0x112; int MXNArrayListSave(const char* fname, mx_uint num_args, diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index 6002da20c1fe..51e67bfb0d04 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -14,13 +14,10 @@ #include #include +namespace mxnet { namespace common { - -/*! - * \brief CUDA utilities. - */ +/*! \brief common utils for cuda */ namespace cuda { - /*! * \brief Get string representation of cuBLAS errors. * \param error The error. @@ -91,6 +88,7 @@ inline const char* CurandGetErrorString(curandStatus_t status) { } // namespace cuda } // namespace common +} // namespace mxnet /*! * \brief Check CUDA error. @@ -153,5 +151,4 @@ inline const char* CurandGetErrorString(curandStatus_t status) { } #endif // MXNET_USE_CUDNN - #endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/common/object_pool.h b/src/common/object_pool.h index 052688ce601a..2e38654f7a4c 100644 --- a/src/common/object_pool.h +++ b/src/common/object_pool.h @@ -9,8 +9,8 @@ #include #include +namespace mxnet { namespace common { - /*! * \brief Object pool for fast allocation and deallocation. */ @@ -172,4 +172,5 @@ void ObjectPoolAllocatable::Delete(T* ptr) { } } // namespace common +} // namespace mxnet #endif // MXNET_COMMON_OBJECT_POOL_H_ diff --git a/src/common/utils.h b/src/common/utils.h index 29cb9f0e2f2a..ffa5c349c65c 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -13,6 +13,7 @@ #include #endif // DMLC_USE_CXX11 +namespace mxnet { namespace common { #if DMLC_USE_CXX11 @@ -112,5 +113,5 @@ typename helper::UniqueIf::KnownBound MakeUnique(Args&&... args) = delete; #endif // DMLC_USE_CXX11 } // namespace common - +} // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 4047099d14cc..75bb58c6f56a 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -1,28 +1,37 @@ /*! 
* Copyright (c) 2015 by Contributors + * \file engine.cc + * \brief Implementation of engine. */ -#include "mxnet/engine.h" -#include "engine_impl.h" -#include "naive_engine.h" -#include "threaded_engine.h" +#include +#include +#include +#include "./engine_impl.h" namespace mxnet { - -Engine::~Engine() noexcept(false) {} +namespace engine { +inline Engine* CreateEngine() { + const char *type = getenv("MXNET_ENGINE_TYPE"); + const bool default_engine = (type == nullptr); + if (type == nullptr) type = "ThreadedEngine"; + std::string stype = type; + Engine *ret = nullptr; + if (stype == "ThreadedEngine") { + ret = CreateThreadedEngine(); + } else if (stype == "NaiveEngine") { + ret = CreateNaiveEngine(); + } + CHECK_NE(ret, nullptr) + << "Cannot find Engine " << type << " in registry"; + if (!default_engine) { + LOG(INFO) << "MXNet start using engine: " << type; + } + return ret; +} +} // namespace engine Engine* Engine::Get() { - /*! - * \brief Change specific engine to use. - */ -#ifdef MXNET_USE_THREADED_ENGINE - using EngineImplementation = engine::ThreadedEngine; -#else // MXNET_USE_THREADED_ENGINE -#warning "Using naive engine."; - using EngineImplementation = engine::NaiveEngine; -#endif // MXNET_USE_THREADED_ENGINE - - static EngineImplementation inst; - return &inst; + static std::unique_ptr inst(engine::CreateEngine()); + return inst.get(); } - } // namespace mxnet diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h index c1ebe2a042f1..cc5ab2e47d6a 100644 --- a/src/engine/engine_impl.h +++ b/src/engine/engine_impl.h @@ -1,56 +1,74 @@ /*! * Copyright (c) 2015 by Contributors + * \file engine_impl.h + * \brief Internal implementation header of engine components. */ #ifndef MXNET_ENGINE_ENGINE_IMPL_H_ #define MXNET_ENGINE_ENGINE_IMPL_H_ -#include -#include "mxnet/engine.h" +#include +/*! \brief MACRO on whether or not enable debug option*/ #define ENGINE_DEBUG 0 namespace mxnet { namespace engine { - +/*! \brief base class of engine variables, used for type checking */ struct Var { #if ENGINE_DEBUG virtual ~Var() = default; #endif // ENGINE_DEBUG + /*! + * \brief cast variable to derived type T + * \tparam T the type we want to cast into. + * \return A casted variable. + */ template - T* Cast(); + inline T* Cast(); }; // struct Var +/*! \brief base class of engine operators, used for type checking */ struct Opr { #if ENGINE_DEBUG virtual ~Opr() = default; -#endif // ENGINE_DEBUG +#endif + /*! + * \brief cast variable to derived type T + * \tparam T the type we want to cast into. + * \return A casted variable. + */ template - T* Cast(); + inline T* Cast(); }; // struct Opr +// implementation of the inline functions template -T* Var::Cast() { +inline T* Var::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Var`"); #if ENGINE_DEBUG return dynamic_cast(this); -#else // ENGINE_DEBUG +#else return static_cast(this); -#endif // ENGINE_DEBUG +#endif } template -T* Opr::Cast() { +inline T* Opr::Cast() { static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Opr`"); #if ENGINE_DEBUG return dynamic_cast(this); -#else // ENGINE_DEBUG +#else return static_cast(this); -#endif // ENGINE_DEBUG +#endif } +// predeclare factory function for each type of engine +/*! \return NaiveEngine instance */ +Engine *CreateNaiveEngine(); +/*!
\return ThreadedEngine instance */ +Engine *CreateThreadedEngine(); } // namespace engine } // namespace mxnet - #endif // MXNET_ENGINE_ENGINE_IMPL_H_ diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 7f38fd92de35..a8c7db598319 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -1,75 +1,103 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015 by Contributors + * \file naive_engine.cc + * \brief Implementation of NaiveEngine */ -#include "naive_engine.h" #include +#include "./engine_impl.h" namespace mxnet { namespace engine { +// implement naive engine +class NaiveEngine final : public Engine { + public: + NaiveEngine() { + } + // virtual destructor + virtual ~NaiveEngine() { +#if MXNET_USE_CUDA + for (size_t i = 0; i < streams_.size(); ++i) { + if (streams_[i] != nullptr) { + mshadow::DeleteStream(streams_[i]); + } + } +#endif + } + // new variables + VarHandle NewVariable() override { + return nullptr; + } + OprHandle NewOperator(AsyncFn fn, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override { + LOG(FATAL) << "Not implemented"; + return nullptr; + } + void DeleteOperator(OprHandle op) override { + LOG(FATAL) << "Not implemented"; + } + void Push(OprHandle op, Context exec_ctx) override { + LOG(FATAL) << "Not implemented"; + } + void PushAsync(AsyncFn exec_fun, + Context exec_ctx, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop) override { + CallbackOnComplete callback = CreateCallback( + NaiveEngine::OnComplete, nullptr); + this->req_completed_ = false; -NaiveEngine::VarHandle NaiveEngine::NewVariable() { return nullptr; } - -NaiveEngine::NaiveEngine() { - #if MXNET_USE_CUDA - #if MXNET_USE_CUDNN == 1 - LOG(INFO) << "MXNet is using CuDNN for Convolution, Pooling Op"; - stream_ = mshadow::NewStream(true, true); - #else - stream_ = mshadow::NewStream(true, false); - #endif // MXNET_USE_CUDNN - ctx_.stream = stream_; - #endif // MXNET_USE_CUDA -} - -NaiveEngine::~NaiveEngine() { - #if MXNET_USE_CUDA - mshadow::DeleteStream(stream_); - #endif -} - -NaiveEngine::OprHandle NaiveEngine::NewOperator(AsyncFn, - std::vector const&, - std::vector const&, - FnProperty) { - LOG(FATAL) << "Not implemented"; - return nullptr; -} - -void NaiveEngine::DeleteOperator(OprHandle) { LOG(FATAL) << "Not implemented"; } - -void NaiveEngine::Push(OprHandle, Context) { LOG(FATAL) << "Not implemented"; } - -void NaiveEngine::Push(Fn exec_fun, Context exec_ctx, - std::vector const&, - std::vector const&, FnProperty) { - if (exec_ctx.dev_mask == gpu::kDevMask) { + if (exec_ctx.dev_mask == gpu::kDevMask) { #if MXNET_USE_CUDA - mshadow::SetDevice(exec_ctx.dev_id); - ctx_.stream = stream_; - exec_fun(ctx_); - stream_->Wait(); + size_t dev_id = static_cast(exec_ctx.dev_id); + mshadow::SetDevice(exec_ctx.dev_id); + if (streams_.size() <= dev_id) { + streams_.resize(dev_id + 1, nullptr); + } + if (streams_[dev_id] == nullptr) { + streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); + } + ctx_.stream = streams_[dev_id]; + exec_fun(ctx_, callback); + streams_[dev_id]->Wait(); #else - LOG(FATAL) << "GPU is not enabled"; + LOG(FATAL) << "GPU is not enabled"; #endif - } else { - exec_fun(ctx_); + } else { + ctx_.stream = &cpu_stream_; + exec_fun(ctx_, callback); + } + CHECK(this->req_completed_) + << "NaiveEngine only support synchronize Push so far"; + } + void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override { + this->PushSync(delete_fn, 
exec_ctx, {}, {var}, FnProperty::kNormal); + } + void WaitForVar(VarHandle var) override { + } + void WaitForAll() override { } -} - -void NaiveEngine::PushAsync(AsyncFn, Context, std::vector const&, - std::vector const&, FnProperty) { - LOG(FATAL) << "Not implemented"; -} - -void NaiveEngine::DeleteVariable(Fn delete_fun, Context exec_ctx, - VarHandle var) { - this->Push(delete_fun, exec_ctx, {}, {var}, FnProperty::kNormal); -} -void NaiveEngine::WaitForVar(VarHandle) {} + private: + // callback to oncomplete + static void OnComplete(Engine *engine, void *param) { + static_cast(engine)->req_completed_ = true; + } + // runtime contetxt + RunContext ctx_; + // whether action is completed + bool req_completed_; + // CPU stream + mshadow::Stream cpu_stream_; + // GPU streams + std::vector*> streams_; +}; // class NaiveEngine -void NaiveEngine::WaitForAll() {} +Engine *CreateNaiveEngine() { + return new NaiveEngine(); +} } // namespace engine } // namespace mxnet - diff --git a/src/engine/naive_engine.h b/src/engine/naive_engine.h deleted file mode 100644 index bbcbc9d2c215..000000000000 --- a/src/engine/naive_engine.h +++ /dev/null @@ -1,47 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - */ -#ifndef MXNET_ENGINE_NAIVE_ENGINE_H_ -#define MXNET_ENGINE_NAIVE_ENGINE_H_ - -#include -#include "engine_impl.h" - -namespace mxnet { - -namespace engine { - -class NaiveEngine final : public Engine { - public: - NaiveEngine(); - ~NaiveEngine(); - VarHandle NewVariable() override; - OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop) override; - void DeleteOperator(OprHandle op) override; - void Push(OprHandle op, Context exec_ctx) override; - void Push(Fn exec_fun, Context exec_ctx, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop) override; - void PushAsync(AsyncFn exec_fun, Context exec_ctx, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop) override; - void DeleteVariable(Fn delete_fun, Context exec_ctx, VarHandle var) override; - void WaitForVar(VarHandle var) override; - void WaitForAll() override; - - private: - RunContext ctx_; -#if MXNET_USE_CUDA - mshadow::Stream* stream_; -#endif // MXNET_USE_CUDA -}; // class NaiveEngine - -} // namespace engine - -} // namespace mxnet - -#endif // MXNET_ENGINE_NAIVE_ENGINE_H_ diff --git a/src/engine/thread_pool.h b/src/engine/thread_pool.h index 292b6c433d45..ef99a93e58d1 100644 --- a/src/engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -12,7 +12,6 @@ #include "mxnet/base.h" namespace mxnet { - namespace engine { /*! 
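With NaiveEngine and ThreadedEngine both created through the factory functions above, CreateEngine() reads MXNET_ENGINE_TYPE the first time Engine::Get() is called, so a process can pick the engine at startup. A small sketch, assuming a POSIX setenv and that no engine call has happened yet (illustration only, not part of the patch):

    #include <cstdlib>
    #include <mxnet/engine.h>

    int main() {
      // Must run before the first Engine::Get(); "ThreadedEngine" is the default.
      setenv("MXNET_ENGINE_TYPE", "NaiveEngine", 1);
      mxnet::Engine* engine = mxnet::Engine::Get();
      engine->WaitForAll();
      return 0;
    }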
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index cd9758835346..b4da2f4b06a9 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -11,7 +11,6 @@ #include "../common/cuda_utils.h" namespace mxnet { - namespace engine { #if ENGINE_DEBUG @@ -142,8 +141,7 @@ ThreadedEngine::~ThreadedEngine() noexcept(false) { } ThreadedVar* ThreadedEngine::NewVariable() { - auto ret = ThreadedVar::New(VersionedVarBlock::New()); - return ret; + return ThreadedVar::New(VersionedVarBlock::New()); } ThreadedOpr* ThreadedEngine::NewOperator( @@ -195,16 +193,18 @@ ThreadedOpr* ThreadedEngine::NewOperator( void ThreadedEngine::DeleteOperator(OprHandle op) { auto&& threaded_opr = ThreadedOpr::CastFromBase(op); - std::vector deps{}; + std::vector deps; deps.reserve(threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size()); - deps.insert(deps.end(), threaded_opr->const_vars.begin(), + deps.insert(deps.end(), + threaded_opr->const_vars.begin(), threaded_opr->const_vars.end()); - deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), + deps.insert(deps.end(), + threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end()); - auto&& func = - [threaded_opr](RunContext) { ThreadedOpr::Delete(threaded_opr); }; - Push(func, Context{}, {}, deps, FnProperty::kAsync); + this->PushSync([threaded_opr](RunContext) { + ThreadedOpr::Delete(threaded_opr); + }, Context(), {}, deps, FnProperty::kAsync); } void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { @@ -232,17 +232,6 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx) { } } -void ThreadedEngine::Push(Fn exec_fun, Context exec_ctx, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop) { - auto f = [exec_fun](RunContext ctx, Callback on_complete) { - exec_fun(ctx); - on_complete(); - }; - PushAsync(f, exec_ctx, const_vars, mutable_vars, prop); -} - void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, @@ -252,34 +241,29 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, Push(opr, exec_ctx); } -void ThreadedEngine::DeleteVariable(Fn delete_fn, Context exec_ctx, +void ThreadedEngine::DeleteVariable(SyncFn delete_fn, + Context exec_ctx, VarHandle var) { - auto&& threaded_var = ThreadedVar::CastFromBase(var); - auto&& func = [delete_fn, threaded_var](RunContext ctx) { - /*! - * Mark variable as orphan, so during `ThreadedEngine::OnComplete` it could - * be recycled. - */ - threaded_var->SetToDelete(); - delete_fn(ctx); - }; - Push(func, exec_ctx, {}, {var}, FnProperty::kAsync); + ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); + this->PushSync([delete_fn, threaded_var](RunContext ctx) { + // Mark variable as orphan, + // so during `ThreadedEngine::OnComplete` it could be recycled. 
+ threaded_var->SetToDelete(); + delete_fn(ctx); + }, exec_ctx, {}, {var}, FnProperty::kAsync); } void ThreadedEngine::WaitForVar(VarHandle var) { - auto&& threaded_var = ThreadedVar::CastFromBase(var); - if (threaded_var->ready_to_read()) { - return; - } + ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); + if (threaded_var->ready_to_read()) return; { std::unique_lock lock{finished_m_}; std::atomic done{false}; - auto&& callback = [this, &done](RunContext) { - std::unique_lock lock{finished_m_}; - done.store(true); - finished_cv_.notify_all(); - }; - Push(callback, Context{}, {var}, {}, FnProperty::kNormal); + this->PushSync([this, &done](RunContext) { + std::unique_lock lock{finished_m_}; + done.store(true); + finished_cv_.notify_all(); + }, Context{}, {var}, {}, FnProperty::kNormal); finished_cv_.wait(lock, [&done]() { return done.load(); }); } } @@ -289,16 +273,12 @@ void ThreadedEngine::WaitForAll() { finished_cv_.wait(lock, [this]() { return pending_.load() == 0; }); } -void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { - /*! - * Mark complete for read variables. - */ +inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { + // Mark complete for read variables for (auto&& i : threaded_opr->const_vars) { i->CompleteReadDependency([this](OprBlock* opr) { DoPushToQueue(opr); }); } - /*! - * Mark complete for write variables. - */ + // Mark complete for write variables. for (auto&& i : threaded_opr->mutable_vars) { bool to_delete = i->CompleteWriteDependency( [this](OprBlock* opr) { DoPushToQueue(opr); }); @@ -312,6 +292,10 @@ void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { finished_cv_.notify_all(); } } + // delete operator if it is temporary + if (threaded_opr->temporary) { + ThreadedOpr::Delete(threaded_opr); + } } void ThreadedEngine::ThreadWorker( @@ -324,24 +308,20 @@ void ThreadedEngine::DoPushToQueue(OprBlock* opr_block) { switch (opr_block->opr->prop) { - case FnProperty::kIO: + case FnProperty::kCopy: { io_task_queue_.Push(opr_block); break; - default: + } + default: { task_queue_.Push(opr_block); break; + } } } void ThreadedEngine::DoExecute(OprBlock* opr_block) { assert(opr_block->wait.load() == 0); - auto threaded_opr = opr_block->opr; - auto callback = [this, threaded_opr]() { - OnComplete(threaded_opr); - if (threaded_opr->temporary) { - ThreadedOpr::Delete(threaded_opr); - } - }; + ThreadedOpr* threaded_opr = opr_block->opr; if (opr_block->ctx.dev_mask == gpu::kDevMask) { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id)); #else LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } - auto&& rctx = opr_block->opr->prop == FnProperty::kIO + auto&& rctx = opr_block->opr->prop == FnProperty::kCopy ?
streams_.GetIORunContext(opr_block->ctx) : streams_.GetRunContext(opr_block->ctx); + CallbackOnComplete callback = this->CreateCallback( + ThreadedEngine::OnComplete_, threaded_opr); threaded_opr->fn(rctx, callback); OprBlock::Delete(opr_block); } +Engine *CreateThreadedEngine() { + return new ThreadedEngine(); +} } // namespace engine - } // namespace mxnet diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index e2f5835b6507..9f3ae3f1c9ba 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -12,13 +12,12 @@ #include #include #include -#include "engine_impl.h" -#include "thread_pool.h" -#include "stream_manager.h" +#include "./engine_impl.h" +#include "./thread_pool.h" +#include "./stream_manager.h" #include "../common/object_pool.h" namespace mxnet { - namespace engine { /*! @@ -131,23 +130,13 @@ class ThreadedEngine final : public Engine { FnProperty prop) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx) override; - void Push(Fn exec_fun, Context exec_ctx, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop) override; - void DeleteVariable(Fn delete_fn, Context exec_ctx, VarHandle var) override; + void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; void WaitForAll() override; - /*! - * \brief Callback on operation completion. - * - * On operation completion, this will trigger subsequent operations. - */ - void OnComplete(ThreadedOpr* threaded_opr); /*! * \brief Worker. * \param task_queue Queue to work on. @@ -155,16 +144,24 @@ class ThreadedEngine final : public Engine { * The method to pass to thread pool to parallelize. */ void ThreadWorker(dmlc::ConcurrentBlockingQueue* task_queue); - - private: /*! - * \brief Concurrency for thread pool. + * \brief Callback on operation completion. + * + * On operation completion, this will trigger subsequent operations. */ + inline void OnComplete(ThreadedOpr* threaded_opr); + // callback to the threaded engine + inline static void OnComplete_(Engine *engine, void *threaded_opr) { + static_cast(engine)->OnComplete( + static_cast(threaded_opr)); + } + + private: + /*! \brief Concurrency for thread pool */ static constexpr std::size_t kNumWorkingThreads = 16; - /*! - * \brief Constants for runtime context. - */ + /*! \brief Maximum number of GPUs */ static constexpr std::size_t kMaxNumGpus = 16; + /*!\brief number of streams allocated for each GPU */ static constexpr std::size_t kNumStreamsPerGpu = 16; /*! * \brief Number of pending operations. 
@@ -206,7 +203,5 @@ class ThreadedEngine final : public Engine { }; // class ThreadedEngine } // namespace engine - } // namespace mxnet - #endif // MXNET_ENGINE_THREADED_ENGINE_H_ diff --git a/src/narray/narray.cc b/src/narray/narray.cc index 22097eb46e86..661e2004079a 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -51,7 +51,7 @@ inline void BinaryOp(const NArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { - Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); @@ -60,7 +60,7 @@ inline void BinaryOp(const NArray &lhs, } #if MXNET_USE_CUDA case gpu::kDevMask: { - Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); @@ -80,7 +80,7 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { NArray ret = *out; switch (ret.ctx().dev_mask) { case cpu::kDevMask: { - Engine::Get()->Push([rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(rhs, &tmp, ctx); @@ -89,7 +89,7 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { } #if MXNET_USE_CUDA case gpu::kDevMask: { - Engine::Get()->Push([rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(rhs, &tmp, ctx); @@ -128,7 +128,7 @@ inline void ScalarOp(const NArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) { case cpu::kDevMask: { - Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(lhs.data(), rhs, &tmp, ctx); @@ -137,7 +137,7 @@ inline void ScalarOp(const NArray &lhs, } #if MXNET_USE_CUDA case gpu::kDevMask: { - Engine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Eval(lhs.data(), rhs, &tmp, ctx); @@ -165,7 +165,7 @@ void CopyFromTo(const NArray &from, NArray *to) { if (from.ptr_->var != ret.ptr_->var) const_vars.push_back(from.ptr_->var); if (a == cpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -174,7 +174,7 @@ void CopyFromTo(const NArray &from, NArray *to) { } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -183,7 +183,7 @@ void CopyFromTo(const NArray &from, NArray *to) { ctx.get_stream()->Wait(); }, ret.ctx(), const_vars, {ret.ptr_->var}); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { - Engine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, @@ -192,7 +192,7 @@ void CopyFromTo(const NArray &from, NArray *to) { ctx.get_stream()->Wait(); }, from.ctx(), 
const_vars, {ret.ptr_->var}); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { - Engine::Get()->Push([from, ret](RunContext ctx) { + Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); narray::Copy(from.data(), &tmp, diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 1211f1a4abb4..2f0bd318cf67 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -202,7 +202,8 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { Operator* op = op_node.op.get(); OpContext* op_ctx_ptr = &op_node.op_ctx; bool is_gpu = op_node.ctx.dev_mask == gpu::kDevMask; - exec.exec_fun = [op, is_gpu, op_ctx_ptr, in_data, req, out_data, aux_states] (RunContext ctx) { + exec.exec_fun = [op, is_gpu, op_ctx_ptr, in_data, req, out_data, aux_states] + (RunContext ctx, Engine::CallbackOnComplete on_complete) { op_ctx_ptr->run_ctx = ctx; op->Forward(*op_ctx_ptr, in_data, req, out_data, aux_states); if (is_gpu) { @@ -213,6 +214,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } + on_complete(); }; return exec; } @@ -472,18 +474,20 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { OpNode& opnode = op_nodes_[nid]; opnode.op_ctx.is_train = is_train; if (opnode.cached_exec.exec_fun != nullptr) { - Engine::Get()->Push( + Engine::Get()->PushAsync( opnode.cached_exec.exec_fun, opnode.ctx, opnode.cached_exec.use_vars, - opnode.cached_exec.mutate_vars); + opnode.cached_exec.mutate_vars, + FnProperty::kNormal); } else { auto exec = GetOpExecEntry(nid); - Engine::Get()->Push( + Engine::Get()->PushAsync( exec.exec_fun, opnode.ctx, exec.use_vars, - exec.mutate_vars); + exec.mutate_vars, + FnProperty::kNormal); } } } diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 5db11fcdf779..823f28b5398e 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -93,7 +93,7 @@ class GraphExecutor : public Executor { // all the information needed to push the op to engine struct OpExecEntry { // execution function for - Engine::Fn exec_fun; + Engine::AsyncFn exec_fun; // variables to read from std::vector use_vars; // variables to mutate diff --git a/tests/.gitignore b/tests/.gitignore index 8144904045d0..1b2fb8f6589a 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1 +1,2 @@ *_test +*_unittest diff --git a/tests/cpp/.gitignore b/tests/cpp/.gitignore new file mode 100644 index 000000000000..b075466b417a --- /dev/null +++ b/tests/cpp/.gitignore @@ -0,0 +1 @@ +unittest diff --git a/tests/test_storage.cc b/tests/cpp/storage_unittest.cc similarity index 54% rename from tests/test_storage.cc rename to tests/cpp/storage_unittest.cc index 33995a055dc5..20a92f4daaf5 100644 --- a/tests/test_storage.cc +++ b/tests/cpp/storage_unittest.cc @@ -1,26 +1,22 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file test_storage.cc - * \brief Test for storage. 
- */ #include -#include -#include "mxnet/storage.h" +#include +#include +#include -int main() { +TEST(Storage, basics) { constexpr size_t kSize = 1024; auto&& storage = mxnet::Storage::Get(); mxnet::Context context_cpu{}; auto&& handle = storage->Alloc(kSize, context_cpu); - assert(handle.ctx == context_cpu); - assert(handle.size == kSize); + ASSERT_EQ(handle.ctx, context_cpu); + ASSERT_EQ(handle.size, kSize); auto ptr = handle.dptr; storage->Free(handle); handle = storage->Alloc(kSize, context_cpu); - assert(handle.ctx == context_cpu); - assert(handle.size == kSize); - assert(handle.dptr == ptr); - printf("Success on CPU!\n"); + ASSERT_EQ(handle.ctx, context_cpu); + ASSERT_EQ(handle.size, kSize); + ASSERT_EQ(handle.dptr, ptr); + LOG(INFO) << "Success on CPU!\n"; #if MXNET_USE_CUDA mxnet::Context context_gpu{mxnet::gpu::kDevMask, 0}; @@ -30,10 +26,9 @@ int main() { ptr = handle.dptr; storage->Free(handle); handle = storage->Alloc(kSize, context_gpu); - assert(handle.ctx == context_gpu); - assert(handle.size == kSize); - assert(handle.dptr == ptr); - printf("Success on GPU!\n"); + ASSERT_EQ(handle.ctx, context_gpu); + ASSERT_EQ(handle.size, kSize); + ASSERT_EQ(handle.dptr, ptr); + LOG(INFO) << "Success on GPU!\n"; #endif // MXNET_USE_CUDA - return 0; } diff --git a/tests/test_threaded_engine.cc b/tests/cpp/threaded_engine_unittest.cc similarity index 87% rename from tests/test_threaded_engine.cc rename to tests/cpp/threaded_engine_unittest.cc index d3708711779a..35e0ca3124b0 100644 --- a/tests/test_threaded_engine.cc +++ b/tests/cpp/threaded_engine_unittest.cc @@ -1,18 +1,16 @@ -/*! - * Copyright (c) 2015 by Contributors - */ #include #include #include +#include #include #include #include -#include "mxnet/engine.h" +#include void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } -int main() { +TEST(Engine, basics) { auto&& engine = mxnet::Engine::Get(); auto&& var = engine->NewVariable(); std::vector oprs; @@ -21,7 +19,7 @@ int main() { printf("============= Test #1 ==============\n"); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, i); std::this_thread::sleep_for(std::chrono::seconds{1}); cb(); @@ -43,7 +41,7 @@ int main() { oprs.clear(); for (int i = 0; i < 10; ++i) { oprs.push_back(engine->NewOperator( - [i](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, i); std::this_thread::sleep_for(std::chrono::milliseconds{500}); cb(); @@ -69,12 +67,12 @@ int main() { var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { std::this_thread::sleep_for(std::chrono::seconds{2}); Foo(ctx, 42); cb(); }, - {}, {var}, mxnet::FnProperty::kIO)); + {}, {var}, mxnet::FnProperty::kCopy)); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; engine->WaitForVar(var); @@ -89,7 +87,7 @@ int main() { var = engine->NewVariable(); oprs.clear(); oprs.push_back(engine->NewOperator( - [](mxnet::RunContext ctx, mxnet::Engine::Callback cb) { + [](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) { Foo(ctx, 42); std::this_thread::sleep_for(std::chrono::seconds{2}); cb(); @@ -108,6 +106,5 @@ int main() { engine->WaitForAll(); var = nullptr; oprs.clear(); - - 
return 0; + LOG(INFO) << "All pass"; } diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk new file mode 100644 index 000000000000..4020dba82a82 --- /dev/null +++ b/tests/cpp/unittest.mk @@ -0,0 +1,16 @@ +UNITTEST_SRC = $(wildcard tests/cpp/*_unittest.cc) +UNITTEST_OBJ = $(patsubst tests/cpp/%_unittest.cc, tests/cpp/%_unittest.o, $(UNITTEST_SRC)) + +GTEST_LIB=$(GTEST_PATH)/lib/ +GTEST_INC=$(GTEST_PATH)/include/ + +tests/cpp/%.o : tests/cpp/%.cc + $(CXX) -std=c++0x $(CFLAGS) -MM -MT tests/$*.o $< >tests/$*.d + $(CXX) -std=c++0x -c $(CFLAGS) -I$(GTEST_INC) -c $< -o $@ + +tests/cpp/unittest: $(UNITTEST_OBJ) lib/libmxnet.a + $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.o %.a, $^) $(LDFLAGS) -lgtest -lgtest_main + +-include tests/cpp/*.d + + diff --git a/tests/python/README.md b/tests/python/README.md new file mode 100644 index 000000000000..02dcb6ea6818 --- /dev/null +++ b/tests/python/README.md @@ -0,0 +1,10 @@ +Python Test Case +================ +This folder contains test cases for mxnet in python. + +* [common](common) contains common utils for all test modules. + - From subfolders, import with ```from ..common import get_data``` +* [unittest](unittest) contains unit test component for each modules. + - These are basic tests that must pass for every commit. +* [train](train) contains tests that runs on real network training. + - These tests can be time consuming. diff --git a/tests/python/get_data.py b/tests/python/common/get_data.py similarity index 100% rename from tests/python/get_data.py rename to tests/python/common/get_data.py diff --git a/tests/python/models.py b/tests/python/common/models.py similarity index 100% rename from tests/python/models.py rename to tests/python/common/models.py diff --git a/tests/python/train/common.py b/tests/python/train/common.py new file mode 100644 index 000000000000..1622e0294e69 --- /dev/null +++ b/tests/python/train/common.py @@ -0,0 +1,6 @@ +import sys, os +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../common/')) + +import models +import get_data diff --git a/tests/python/test_conv.py b/tests/python/train/test_conv.py similarity index 99% rename from tests/python/test_conv.py rename to tests/python/train/test_conv.py index d63a0542ce7a..4affe6d8f200 100644 --- a/tests/python/test_conv.py +++ b/tests/python/train/test_conv.py @@ -2,8 +2,7 @@ import mxnet as mx import numpy as np import os, pickle, gzip -import sys -import get_data +from common import get_data def CalAcc(out, label): pred = np.argmax(out, axis=1) diff --git a/tests/python/test_mlp.py b/tests/python/train/test_mlp.py similarity index 98% rename from tests/python/test_mlp.py rename to tests/python/train/test_mlp.py index 85abbb9ac216..e2b6dcee8488 100644 --- a/tests/python/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -3,9 +3,8 @@ import numpy as np import os, gzip import pickle as pickle -import sys -import get_data - +from common import get_data + def CalAcc(out, label): pred = np.argmax(out, axis=1) return np.sum(pred == label) * 1.0 / out.shape[0] diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py new file mode 100644 index 000000000000..1622e0294e69 --- /dev/null +++ b/tests/python/unittest/common.py @@ -0,0 +1,6 @@ +import sys, os +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../common/')) + +import models +import get_data diff --git a/tests/python/test_bind.py 
b/tests/python/unittest/test_bind.py similarity index 100% rename from tests/python/test_bind.py rename to tests/python/unittest/test_bind.py diff --git a/tests/python/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py similarity index 97% rename from tests/python/test_infer_shape.py rename to tests/python/unittest/test_infer_shape.py index b7f1efd75225..d80dbb713548 100644 --- a/tests/python/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -1,6 +1,6 @@ # pylint: skip-file import mxnet as mx -import models +from common import models from nose.tools import * def test_mlp2_infer_shape(): diff --git a/tests/python/test_io.py b/tests/python/unittest/test_io.py similarity index 98% rename from tests/python/test_io.py rename to tests/python/unittest/test_io.py index 54b538f13eba..e606f9254b5a 100644 --- a/tests/python/test_io.py +++ b/tests/python/unittest/test_io.py @@ -3,8 +3,7 @@ import numpy as np import os, gzip import pickle as pickle -import sys -import get_data +from common import get_data #from PIL import Image @@ -104,4 +103,4 @@ def test_Cifar10Rec(): ''' if __name__ == "__main__": - test_MNISTIter() \ No newline at end of file + test_MNISTIter() diff --git a/tests/python/test_kvstore.py b/tests/python/unittest/test_kvstore.py similarity index 100% rename from tests/python/test_kvstore.py rename to tests/python/unittest/test_kvstore.py diff --git a/tests/python/test_narray.py b/tests/python/unittest/test_narray.py similarity index 100% rename from tests/python/test_narray.py rename to tests/python/unittest/test_narray.py diff --git a/tests/python/test_operator.py b/tests/python/unittest/test_operator.py similarity index 100% rename from tests/python/test_operator.py rename to tests/python/unittest/test_operator.py diff --git a/tests/python/test_symbol.py b/tests/python/unittest/test_symbol.py similarity index 97% rename from tests/python/test_symbol.py rename to tests/python/unittest/test_symbol.py index b08f6a310570..b4dc93e1cfdd 100644 --- a/tests/python/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -1,5 +1,5 @@ import mxnet as mx -import models +from common import models def test_symbol_basic(): mlist = [] From 72bab4cfcda462d06ac516f9526857bfd6f5451b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 12 Sep 2015 21:50:21 -0700 Subject: [PATCH 06/13] update doc, fix test --- doc/sphinx_util.py | 1 + scripts/travis_script.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/sphinx_util.py b/doc/sphinx_util.py index 10ed6fddbe46..fd8aaf82e069 100644 --- a/doc/sphinx_util.py +++ b/doc/sphinx_util.py @@ -14,6 +14,7 @@ def run_build_mxnet(folder): subprocess.call('cd ..; rm -rf mshadow;' + 'git clone https://github.com/dmlc/mshadow', shell = True) subprocess.call('cd ..; cp make/readthedocs.mk config.mk', shell = True) + subprocess.call('cd ..; rm -rf build', shell = True) retcode = subprocess.call("cd %s; make" % folder, shell = True) if retcode < 0: sys.stderr.write("build terminated by signal %s" % (-retcode)) diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 99d1771d1ac7..9f6fc9b159e6 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -55,7 +55,7 @@ if [ ${TASK} == "cpp_unittest" ]; then echo "USE_CUDA=0" >> config.mk make test || exit -1 export MXNET_ENGINE_TYPE=NaiveEngine - testsp/cpp/unittest || exit -1 + tests/cpp/unittest || exit -1 export MXNET_ENGINE_TYPE=ThreadedEngine tests/cpp/unittest || exit -1 fi From 47bdb1d223e716c9e8fe464d8115e99354e91534 Mon Sep 17 
00:00:00 2001 From: tqchen Date: Sat, 12 Sep 2015 21:59:11 -0700 Subject: [PATCH 07/13] only unit test threaded engine --- scripts/travis_script.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 9f6fc9b159e6..1b250afdf70b 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -54,8 +54,6 @@ fi if [ ${TASK} == "cpp_unittest" ]; then echo "USE_CUDA=0" >> config.mk make test || exit -1 - export MXNET_ENGINE_TYPE=NaiveEngine - tests/cpp/unittest || exit -1 export MXNET_ENGINE_TYPE=ThreadedEngine tests/cpp/unittest || exit -1 fi From d747857d1a055e839b015176eee225598ba9de37 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 12 Sep 2015 23:38:05 -0600 Subject: [PATCH 08/13] minor switch to kwargs --- example/cifar10/cifar10.py | 2 +- python/mxnet/symbol.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 892f6fff5d1d..9b387b8d297a 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -162,7 +162,7 @@ def RandomInit(narray): data_shape = (batch_size, 3, 28, 28) in_data = mx.narray.empty(data_shape, mx.gpu()) -executor = loss.simple_bind(mx.gpu(), {"data": in_data}) +executor = loss.simple_bind(mx.gpu(), data = in_data) out_narray = executor.heads()[0] pred = mx.narray.zeros(out_narray.shape, mx.cpu()) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 90df0b663615..4cf15d7c60f5 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -332,14 +332,19 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): raise TypeError('Only Accept list of NArrays or dict of str->NArray') return c_array(NArrayHandle, arg_handles) - def simple_bind(self, ctx, args, grad_req='write'): + def simple_bind(self, ctx, grad_req='write', **kwargs): """Simply bind current symbol to get an executor Parameters ---------- ctx : Context The device context the generated executor to run on. - - args : list of NArray or dict of str->NArray + grad_req: string + {'write', 'add', 'null'}, or list of str or dict of str->str, optional + Specifies how we should update the gradient to the args_grad. + - 'write' means everytime gradient is write to specified args_grad NArray. + - 'add' means everytime gradient is add to the specified NArray. + - 'null' means no action is taken, the gradient may not be calculated. + kwargs : dict of str->NArray Input arguments to the symbol. 
- type is dict of str->NArray, then it maps the name of arguments to the corresponding NArray, @@ -349,9 +354,7 @@ def simple_bind(self, ctx, args, grad_req='write'): executor : mxnet.Executor The generated Executor """ - if not isinstance(args, dict): - raise TypeError("args must be dict of str->NArray") - input_shapes = dict((name, arr.shape) for name, arr in args.items()) + input_shapes = dict((name, arr.shape) for name, arr in kwargs.items()) # pylint: disable=unused-variable arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) # pylint: enable=unused-variable @@ -360,8 +363,8 @@ def simple_bind(self, ctx, args, grad_req='write'): # alloc space arg_narrays = [] for name, shape in zip(self.list_arguments(), arg_shapes): - if name in args: - arg_narrays.append(args[name]) + if name in kwargs: + arg_narrays.append(kwargs[name]) else: arg_narrays.append(zeros(shape, ctx)) # TODO(bing): specail treat input data grad From 65d51108d24b653b21d9b74b3fce910c8658ef31 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 13 Sep 2015 10:41:34 -0700 Subject: [PATCH 09/13] rename narray->ndarray --- Makefile | 9 +- include/mxnet/c_api.h | 273 ++++++---------- include/mxnet/context.h | 1 + include/mxnet/io.h | 5 +- include/mxnet/kvstore.h | 14 +- include/mxnet/{narray.h => ndarray.h} | 292 ++++++++--------- include/mxnet/operator.h | 2 +- include/mxnet/symbolic.h | 20 +- make/config.mk | 28 +- python/mxnet/__init__.py | 4 +- python/mxnet/base.py | 4 +- python/mxnet/executor.py | 34 +- python/mxnet/io.py | 14 +- python/mxnet/kvstore.py | 26 +- python/mxnet/{narray.py => ndarray.py} | 302 +++++++++--------- python/mxnet/symbol.py | 106 +++--- src/c_api.cc | 234 +++++++------- src/kvstore/kvstore_local.h | 22 +- src/{narray/narray.cc => ndarray/ndarray.cc} | 198 ++++++------ .../ndarray_function-inl.h} | 28 +- .../ndarray_function.cc} | 10 +- .../ndarray_function.cu} | 8 +- .../ndarray_function.h} | 16 +- src/symbol/graph_executor.cc | 28 +- src/symbol/graph_executor.h | 30 +- src/symbol/graph_memory_allocator.h | 22 +- tests/python/train/test_conv.py | 8 +- tests/python/train/test_mlp.py | 8 +- tests/python/unittest/test_bind.py | 12 +- tests/python/unittest/test_kvstore.py | 20 +- .../{test_narray.py => test_ndarray.py} | 20 +- tests/python/unittest/test_operator.py | 12 +- 32 files changed, 849 insertions(+), 961 deletions(-) rename include/mxnet/{narray.h => ndarray.h} (59%) rename python/mxnet/{narray.py => ndarray.py} (64%) rename src/{narray/narray.cc => ndarray/ndarray.cc} (64%) rename src/{narray/narray_function-inl.h => ndarray/ndarray_function-inl.h} (88%) rename src/{narray/narray_function.cc => ndarray/ndarray_function.cc} (73%) rename src/{narray/narray_function.cu => ndarray/ndarray_function.cu} (93%) rename src/{narray/narray_function.h => ndarray/ndarray_function.h} (80%) rename tests/python/unittest/{test_narray.py => test_ndarray.py} (87%) diff --git a/Makefile b/Makefile index 879f534b0a9c..abdc4d7b5444 100644 --- a/Makefile +++ b/Makefile @@ -54,14 +54,7 @@ else CFLAGS+= -DMXNET_USE_OPENCV=0 endif -# setup opencv -ifeq ($(USE_OPENCV_DECODER),1) - CFLAGS+= -DMXNET_USE_OPENCV_DECODER=1 -else - CFLAGS+= -DMXNET_USE_OPENCV_DECODER=0 -endif - -ifeq ($(USE_OPENMP_ITER), 1) +ifeq ($(USE_OPENMP), 1) CFLAGS += -fopenmp endif diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 3c74b1c93605..494807a47fcc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -26,9 +26,9 @@ typedef float mx_float; // all the handles are simply void * // will be casted 
internally to specific pointers types // these typedefs are mainly used for readablity reasons -/*! \brief handle to NArray */ -typedef void *NArrayHandle; -/*! \brief handle to a mxnet narray function that changes NArray */ +/*! \brief handle to NDArray */ +typedef void *NDArrayHandle; +/*! \brief handle to a mxnet narray function that changes NDArray */ typedef const void *FunctionHandle; /*! \brief handle to a function that takes param and creates symbol */ typedef void *AtomicSymbolCreator; @@ -53,18 +53,18 @@ typedef void *DataIterHandle; */ MXNET_DLL const char *MXGetLastError(); //------------------------------------- -// Part 1: NArray creation and deletion +// Part 1: NDArray creation and deletion //------------------------------------- /*! - * \brief create a NArray handle that is not initialized + * \brief create a NDArray handle that is not initialized * can be used to pass in as mutate variables - * to hold the result of NArray + * to hold the result of NDArray * \param out the returning handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayCreateNone(NArrayHandle *out); +MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out); /*! - * \brief create a NArray with specified shape + * \brief create a NDArray with specified shape * \param shape the pointer to the shape * \param ndim the dimension of the shape * \param dev_mask device mask, specify device we want to take @@ -74,43 +74,43 @@ MXNET_DLL int MXNArrayCreateNone(NArrayHandle *out); * \param out the returning handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayCreate(const mx_uint *shape, - mx_uint ndim, - int dev_mask, - int dev_id, - int delay_alloc, - NArrayHandle *out); +MXNET_DLL int MXNDArrayCreate(const mx_uint *shape, + mx_uint ndim, + int dev_mask, + int dev_id, + int delay_alloc, + NDArrayHandle *out); /*! - * \brief create a NArray handle that is loaded from raw bytes. + * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes * \param size size of the raw bytes * \param out the returning handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayLoadFromRawBytes(const void *buf, - mx_ulong size, - NArrayHandle *out); +MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, + mx_ulong size, + NDArrayHandle *out); /*! - * \brief save the NArray into raw bytes. - * \param handle the NArray handle + * \brief save the NDArray into raw bytes. + * \param handle the NDArray handle * \param out_size size of the raw bytes * \param out_buf the head of returning memory bytes. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArraySaveRawBytes(NArrayHandle handle, - mx_ulong *out_size, - const char **out_buf); +MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, + mx_ulong *out_size, + const char **out_buf); /*! * \brief Save list of narray into the file. * \param fname name of the file. * \param num_args number of arguments to save. - * \param args the array of NArrayHandles to be saved. - * \param keys the name of the NArray, optional, can be NULL + * \param args the array of NDArrayHandles to be saved. + * \param keys the name of the NDArray, optional, can be NULL * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayListSave(const char* fname, +MXNET_DLL int MXNDArrayListSave(const char* fname, mx_uint num_args, - NArrayHandle* args, + NDArrayHandle* args, const char** keys); /*! * \brief Load list of narray from the file. 
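As a rough, hypothetical illustration of what the renamed C entry points above (MXNDArrayCreateNone, MXNDArrayCreate, the raw-bytes and list save/load calls) mean for a caller — this is not part of the patch — the creation call could be driven directly through ctypes. The shared-library name, the CPU device-mask value, and the error handling below are assumptions; in practice python/mxnet/ndarray.py wraps these calls.

```python
# Hypothetical sketch: calling the renamed C API directly via ctypes.
# Assumptions (not from the patch): the shared library is named libmxnet.so
# and the CPU device mask is 1.  Real users go through python/mxnet/ndarray.py.
import ctypes

_LIB = ctypes.CDLL("libmxnet.so")
_LIB.MXGetLastError.restype = ctypes.c_char_p

NDArrayHandle = ctypes.c_void_p
shape = (2, 3)
c_shape = (ctypes.c_uint * len(shape))(*shape)
handle = NDArrayHandle()

# int MXNDArrayCreate(shape, ndim, dev_mask, dev_id, delay_alloc, &out)
ret = _LIB.MXNDArrayCreate(c_shape,
                           ctypes.c_uint(len(shape)),
                           ctypes.c_int(1),   # cpu device mask (assumed value)
                           ctypes.c_int(0),   # device id
                           ctypes.c_int(0),   # allocate immediately
                           ctypes.byref(handle))
if ret != 0:
    raise RuntimeError(_LIB.MXGetLastError())

_LIB.MXNDArrayFree(handle)
```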
@@ -118,68 +118,68 @@ MXNET_DLL int MXNArrayListSave(const char* fname, * \param out_size number of narray loaded. * \param out_arr head of the returning narray handles. * \param out_name_size size of output name arrray. - * \param out_names the names of returning NArrays, can be NULL + * \param out_names the names of returning NDArrays, can be NULL * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayListLoad(const char* fname, - mx_uint *out_size, - NArrayHandle** out_arr, - mx_uint *out_name_size, - const char*** out_names); +MXNET_DLL int MXNDArrayListLoad(const char* fname, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names); /*! * \brief Perform a synchronize copy from a continugous CPU memory region. * * This function will call WaitToWrite before the copy is performed. * This is useful to copy data from existing memory region that are - * not wrapped by NArray(thus dependency not being tracked). + * not wrapped by NDArray(thus dependency not being tracked). * - * \param handle the NArray handle + * \param handle the NDArray handle * \param data the data source to copy from. * \param size the memory size we want to copy from. */ -MXNET_DLL int MXNArraySyncCopyFromCPU(NArrayHandle handle, - const mx_float *data, - size_t size); +MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, + const mx_float *data, + size_t size); /*! * \brief Perform a synchronize copyto a continugous CPU memory region. * * This function will call WaitToRead before the copy is performed. * This is useful to copy data from existing memory region that are - * not wrapped by NArray(thus dependency not being tracked). + * not wrapped by NDArray(thus dependency not being tracked). * - * \param handle the NArray handle + * \param handle the NDArray handle * \param data the data source to copy into. * \param size the memory size we want to copy into. */ -MXNET_DLL int MXNArraySyncCopyToCPU(NArrayHandle handle, +MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, mx_float *data, size_t size); /*! - * \brief Wait until all the pending writes with respect NArray are finished. + * \brief Wait until all the pending writes with respect NDArray are finished. * Always call this before read data out synchronizely. - * \param handle the NArray handle + * \param handle the NDArray handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayWaitToRead(NArrayHandle handle); +MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle); /*! - * \brief Wait until all the pending read/write with respect NArray are finished. - * Always call this before write data into NArray synchronizely. - * \param handle the NArray handle + * \brief Wait until all the pending read/write with respect NDArray are finished. + * Always call this before write data into NDArray synchronizely. + * \param handle the NDArray handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayWaitToWrite(NArrayHandle handle); +MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle); /*! * \brief wait until all delayed operations in * the system is completed * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayWaitAll(); +MXNET_DLL int MXNDArrayWaitAll(); /*! * \brief free the narray handle * \param handle the handle to be freed * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayFree(NArrayHandle handle); +MXNET_DLL int MXNDArrayFree(NDArrayHandle handle); /*! 
* \brief get the shape of the array * \param handle the handle to the narray @@ -187,30 +187,30 @@ MXNET_DLL int MXNArrayFree(NArrayHandle handle); * \param out_pdata pointer holder to get data pointer of the shape * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetShape(NArrayHandle handle, +MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); /*! - * \brief get the content of the data in NArray + * \brief get the content of the data in NDArray * \param handle the handle to the narray * \param out_pdata pointer holder to get pointer of data * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetData(NArrayHandle handle, +MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, mx_float **out_pdata); /*! - * \brief get the context of the NArray + * \brief get the context of the NDArray * \param handle the handle to the narray * \param out_dev_mask the output device mask * \param out_dev_id the output device id * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetContext(NArrayHandle handle, +MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_mask, int *out_dev_id); //-------------------------------- -// Part 2: functions on NArray +// Part 2: functions on NDArray //-------------------------------- /*! * \brief list all the available functions handles @@ -250,9 +250,9 @@ MXNET_DLL int MXFuncGetInfo(FunctionHandle fun, /*! * \brief get the argument requirements of the function * \param fun input function handle - * \param num_use_vars how many NArrays to be passed in as used_vars + * \param num_use_vars how many NDArrays to be passed in as used_vars * \param num_scalars scalar variable is needed - * \param num_mutate_vars how many NArrays to be passed in as mutate_vars + * \param num_mutate_vars how many NDArrays to be passed in as mutate_vars * \param type_mask the type mask of this function * \return 0 when success, -1 when failure happens * \sa MXFuncInvoke @@ -273,9 +273,9 @@ MXNET_DLL int MXFuncDescribe(FunctionHandle fun, * \sa MXFuncDescribeArgs */ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, - NArrayHandle *use_vars, + NDArrayHandle *use_vars, mx_float *scalar_args, - NArrayHandle *mutate_vars); + NDArrayHandle *mutate_vars); //-------------------------------------------- // Part 3: symbolic configuration generation @@ -486,16 +486,16 @@ MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train); * * \param handle execute handle * \param len lenth - * \param head_grads NArray handle for heads' gradient + * \param head_grads NDArray handle for heads' gradient * * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, - NArrayHandle *head_grads); + NDArrayHandle *head_grads); /*! - * \brief Get executor's head NArray + * \brief Get executor's head NDArray * * \param handle executor handle * \param out_size output narray vector size @@ -504,7 +504,7 @@ MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, */ MXNET_DLL int MXExecutorHeads(ExecutorHandle handle, mx_uint *out_size, - NArrayHandle **out); + NDArrayHandle **out); /*! 
* \brief Generate Executor from symbol @@ -525,11 +525,11 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, int dev_mask, int dev_id, mx_uint len, - NArrayHandle *in_args, - NArrayHandle *arg_grad_store, + NDArrayHandle *in_args, + NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, - NArrayHandle *aux_states, + NDArrayHandle *aux_states, ExecutorHandle *out); //-------------------------------------------- @@ -554,10 +554,10 @@ MXNET_DLL int MXListDataIters(mx_uint *out_size, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, - int num_param, - const char **keys, - const char **vals, - DataIterHandle *out); + int num_param, + const char **keys, + const char **vals, + DataIterHandle *out); /*! * \brief Get the detailed information about data iterator. * \param creator the DataIterCreator. @@ -570,106 +570,12 @@ MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - mx_uint *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions); -/*! - * \brief Free the handle to the IO module - * \param handle the handle pointer to the data iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterFree(DataIterHandle handle); -/*! - * \brief get the name of iterator entry - * \param iter iterator entry - * \param out_name the name of the iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterGetName(DataIterCreator iter, - const char **out_name); -/*! - * \brief Init an iterator, init with parameters - * the array size of passed in arguments - * \param handle of the iterator creator - * \param num_param number of parameter - * \param keys parameter keys - * \param vals parameter values - * \param out resulting iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, - int num_param, - const char **keys, - const char **vals, - DataIterHandle *out); -/*! - * \brief Get the detailed information about data iterator. - * \param creator the DataIterCreator. - * \param name The returned name of the creator. - * \param description The returned description of the symbol. - * \param num_args Number of arguments. - * \param arg_names Name of the arguments. - * \param arg_type_infos Type informations about the arguments. - * \param arg_descriptions Description information about the arguments. - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - mx_uint *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions); -/*! - * \brief Free the handle to the IO module - * \param handle the handle pointer to the data iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterFree(DataIterHandle handle); -/*! - * \brief Get the name of iterator entry - * \param iter iterator entry - * \param out_name the name of the iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterGetName(DataIterCreator iter, - const char **out_name); -/*! 
- * \brief Init an iterator, init with parameters - * the array size of passed in arguments - * \param handle of the iterator creator - * \param num_param number of parameter - * \param keys parameter keys - * \param vals parameter values - * \param out resulting iterator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, - int num_param, - const char **keys, - const char **vals, - DataIterHandle *out); -/*! - * \brief Get the detailed information about data iterator. - * \param creator the DataIterCreator. - * \param name The returned name of the creator. - * \param description The returned description of the symbol. - * \param num_args Number of arguments. - * \param arg_names Name of the arguments. - * \param arg_type_infos Type informations about the arguments. - * \param arg_descriptions Description information about the arguments. - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - mx_uint *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions); + const char **name, + const char **description, + mx_uint *num_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions); /*! * \brief Free the handle to the IO module * \param handle the handle pointer to the data iterator @@ -683,7 +589,7 @@ MXNET_DLL int MXDataIterFree(DataIterHandle handle); * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterNext(DataIterHandle handle, - int *out); + int *out); /*! * \brief Call iterator.Reset * \param handle the handle to iterator @@ -692,21 +598,24 @@ MXNET_DLL int MXDataIterNext(DataIterHandle handle, MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); /*! - * \brief Get the handle to the NArray of underlying data + * \brief Get the handle to the NDArray of underlying data * \param handle the handle pointer to the data iterator - * \param out handle to underlying data NArray + * \param out handle to underlying data NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, - NArrayHandle *out); + NDArrayHandle *out); /*! - * \brief Get the handle to the NArray of underlying label + * \brief Get the handle to the NDArray of underlying label * \param handle the handle pointer to the data iterator - * \param out the handle to underlying label NArray + * \param out the handle to underlying label NDArray * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, - NArrayHandle *out); + NDArrayHandle *out); +//-------------------------------------------- +// Part 5: KVStore interface +//-------------------------------------------- /*! * \brief start the kvstore * \return 0 when success, -1 when failure happens @@ -727,7 +636,7 @@ MXNET_DLL int MXKVStoreStop(); */ MXNET_DLL int MXKVStoreInit(int num, int* keys, - NArrayHandle* vals); + NDArrayHandle* vals); /*! * \brief Push a list of (key,value) pairs to kvstore @@ -738,7 +647,7 @@ MXNET_DLL int MXKVStoreInit(int num, */ MXNET_DLL int MXKVStorePush(int num, int* keys, - NArrayHandle* vals); + NDArrayHandle* vals); /*! @@ -750,7 +659,7 @@ MXNET_DLL int MXKVStorePush(int num, */ MXNET_DLL int MXKVStorePull(int num, int* keys, - NArrayHandle* vals); + NDArrayHandle* vals); /*! 
* \brief user-defined updater for the kvstore @@ -759,7 +668,7 @@ MXNET_DLL int MXKVStorePull(int num, * \param recv the pushed value on this key * \param local the value stored on local on this key */ -typedef void (MXKVStoreUpdater)(int key, NArrayHandle recv, NArrayHandle local); +typedef void (MXKVStoreUpdater)(int key, NDArrayHandle recv, NDArrayHandle local); /*! * \brief register an push updater diff --git a/include/mxnet/context.h b/include/mxnet/context.h index c0a712bc8ec8..a7ed35d21263 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -5,6 +5,7 @@ */ #ifndef MXNET_CONTEXT_H_ #define MXNET_CONTEXT_H_ + #include #include #include diff --git a/include/mxnet/io.h b/include/mxnet/io.h index 7bb86f4eece3..43dd5fad92d1 100644 --- a/include/mxnet/io.h +++ b/include/mxnet/io.h @@ -5,6 +5,7 @@ */ #ifndef MXNET_IO_H_ #define MXNET_IO_H_ + #include #include #include @@ -23,7 +24,7 @@ class IIterator : public dmlc::DataIter { /*! * \brief set the parameters and init iter * \param kwargs key-value pairs - */ + */ virtual void Init(const std::vector >& kwargs) = 0; /*! \brief reset the iterator */ virtual void BeforeFirst(void) = 0; @@ -33,7 +34,7 @@ class IIterator : public dmlc::DataIter { virtual const DType &Value(void) const = 0; /*! \brief constructor */ virtual ~IIterator(void) {} - /*! \brief store the name of each data, it could be used for making NArrays */ + /*! \brief store the name of each data, it could be used for making NDArrays */ std::vector data_names; /*! \brief set data name to each attribute of data */ inline void SetDataName(const std::string data_name){ diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index ef4be5102578..0acb05f0fa9d 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -10,7 +10,7 @@ #if DMLC_USE_CXX11 #include #endif // DMLC_USE_CXX11 -#include "narray.h" +#include "ndarray.h" namespace mxnet { @@ -52,7 +52,7 @@ class KVStore { * \param values a list of values */ virtual void Init(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { CHECK_EQ(keys.size(), values.size()); get_impl()->Init(keys, values); } @@ -80,14 +80,14 @@ class KVStore { * for (auto& v : values) v.WaitToWrite() * \endcode * - * One must call Init() on every key before. And the value Narray should be + * One must call Init() on every key before. And the value NDArray should be * always has the same shape as being inited. * * \param keys the list of keys * \param value the list of values */ virtual void Push(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { CHECK_EQ(keys.size(), values.size()); if (keys.empty()) return; get_impl()->Push(keys, values); @@ -110,7 +110,7 @@ class KVStore { * \param values the list of buffers for the pulled data, they should be preallocated */ virtual void Pull(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { get_impl()->Pull(keys, values); } @@ -118,11 +118,11 @@ class KVStore { /** * \brief the prototype of user-defined updater */ - typedef std::function Updater; + typedef std::function Updater; /*! 
\brief returns the default updater, which is ASSIGN */ Updater DefaultUpdater() { - return [](int key, const NArray& a, NArray* b) { CopyFromTo(a, b); }; + return [](int key, const NDArray& a, NDArray* b) { CopyFromTo(a, b); }; } /** diff --git a/include/mxnet/narray.h b/include/mxnet/ndarray.h similarity index 59% rename from include/mxnet/narray.h rename to include/mxnet/ndarray.h index 20372524b3ac..c8ec4528202e 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/ndarray.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2015 by Contributors - * \file narray.h - * \brief narray interface that dynamically schedules operations + * \file ndarray.h + * \brief NDArray interface that handles array arithematics. */ -#ifndef MXNET_NARRAY_H_ -#define MXNET_NARRAY_H_ +#ifndef MXNET_NDARRAY_H_ +#define MXNET_NDARRAY_H_ #include #include @@ -19,39 +19,39 @@ #include "./engine.h" // check c++11 #if DMLC_USE_CXX11 == 0 -#error "cxx11 was required for narray module" +#error "cxx11 was required for ndarray module" #endif namespace mxnet { /*! * \brief ndarray interface */ -class NArray { +class NDArray { public: /*! \brief default cosntructor */ - NArray() {} + NDArray() {} /*! - * \brief constructing a new dynamic NArray + * \brief constructing a new dynamic NDArray * \param shape the shape of array - * \param ctx context of NArray + * \param ctx context of NDArray * \param delay_alloc whether delay the allocation */ - NArray(const TShape &shape, Context ctx, - bool delay_alloc = false) + NDArray(const TShape &shape, Context ctx, + bool delay_alloc = false) : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc)), shape_(shape), offset_(0) { } /*! - * \brief constructing a static NArray that shares data with TBlob - * Use with caution: allocate ONLY ONE NArray for each TBlob, - * make sure the memory region is available through out the life of NArray + * \brief constructing a static NDArray that shares data with TBlob + * Use with caution: allocate ONLY ONE NDArray for each TBlob, + * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at */ - NArray(const TBlob &data, int dev_id) + NDArray(const TBlob &data, int dev_id) : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), offset_(0) { } /*! - * \return the shape of current NArray + * \return the shape of current NDArray */ inline const TShape &shape() const { return shape_; @@ -64,18 +64,18 @@ class NArray { shape_, ptr_->shandle.ctx.dev_mask); } /*! - * \return the context of NArray, this function is only valid when the NArray is not empty + * \return the context of NDArray, this function is only valid when the NDArray is not empty */ inline Context ctx() const { return ptr_->shandle.ctx; } - /*! \return whether this narray is not initialized */ + /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; } /*! * \brief Block until all the pending write operations with respect - * to current NArray are finished, and read can be performed. + * to current NDArray are finished, and read can be performed. */ inline void WaitToRead() const { if (is_none()) return; @@ -83,7 +83,7 @@ class NArray { } /*! * \brief Block until all the pending read/write operations with respect - * to current NArray are finished, and write can be performed. + * to current NDArray are finished, and write can be performed. 
*/ inline void WaitToWrite() const { if (is_none()) return; @@ -94,7 +94,7 @@ class NArray { Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var}); Engine::Get()->WaitForVar(ptr_->var); } - /*! \return the associated variable of the narray.*/ + /*! \return the associated variable of the ndarray.*/ inline Engine::VarHandle var() const { return ptr_->var; } @@ -110,84 +110,84 @@ class NArray { */ bool Load(dmlc::Stream *strm); /*! - * \brief set all the elements in narray to be scalar + * \brief set all the elements in ndarray to be scalar * \param scalar the scalar to set * \return reference of self */ - NArray &operator=(real_t scalar); + NDArray &operator=(real_t scalar); /*! * \brief elementwise add to current space - * this mutate the current NArray + * this mutate the current NDArray * \param src the data to add * \return reference of self */ - NArray &operator+=(const NArray &src); + NDArray &operator+=(const NDArray &src); /*! * \brief elementwise add to current space - * this mutate the current NArray + * this mutate the current NDArray * \param src the data to add * \return reference of self */ - NArray &operator+=(const real_t &src); + NDArray &operator+=(const real_t &src); /*! - * \brief elementwise subtract from current narray - * this mutate the current NArray + * \brief elementwise subtract from current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator-=(const NArray &src); + NDArray &operator-=(const NDArray &src); /*! - * \brief elementwise subtract from current narray - * this mutate the current NArray + * \brief elementwise subtract from current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator-=(const real_t &src); + NDArray &operator-=(const real_t &src); /*! - * \brief elementwise multiplication to current narray - * this mutate the current NArray + * \brief elementwise multiplication to current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator*=(const NArray &src); + NDArray &operator*=(const NDArray &src); /*! - * \brief elementwise multiplication to current narray - * this mutate the current NArray + * \brief elementwise multiplication to current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator*=(const real_t &src); + NDArray &operator*=(const real_t &src); /*! - * \brief elementwise division from current narray - * this mutate the current NArray + * \brief elementwise division from current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator/=(const NArray &src); + NDArray &operator/=(const NDArray &src); /*! - * \brief elementwise division from current narray - * this mutate the current NArray + * \brief elementwise division from current ndarray + * this mutate the current NDArray * \param src the data to substract * \return reference of self */ - NArray &operator/=(const real_t &src); + NDArray &operator/=(const real_t &src); /*! - * \brief return transpose of current NArray - * \return a new transposed NArray + * \brief return transpose of current NDArray + * \return a new transposed NDArray */ - NArray T() const; + NDArray T() const; /*! 
- * \brief return a new copy this NArray - * \param ctx the new context of this NArray + * \brief return a new copy this NDArray + * \param ctx the new context of this NDArray * \return the new copy */ - NArray Copy(Context ctx) const; + NDArray Copy(Context ctx) const; /*! * \brief Do a synchronize copy from a continugous CPU memory region. * * This function will call WaitToWrite before the copy is performed. * This is useful to copy data from existing memory region that are - * not wrapped by NArray(thus dependency not being tracked). + * not wrapped by NDArray(thus dependency not being tracked). * * \param data the data source to copy from. * \param size the memory size we want to copy from. @@ -198,21 +198,21 @@ class NArray { * * This function will call WaitToRead before the copy is performed. * This is useful to copy data from existing memory region that are - * not wrapped by NArray(thus dependency not being tracked). + * not wrapped by NDArray(thus dependency not being tracked). * * \param data the data source to copyinto. * \param size the memory size we want to copy into. */ void SyncCopyToCPU(real_t *data, size_t size) const; /*! - * \brief Slice a NArray + * \brief Slice a NDArray * \param begin begin index in first dim * \param end end index in first dim - * \return sliced NArray + * \return sliced NDArray */ - inline NArray Slice(index_t begin, index_t end) const { - NArray ret = *this; - CHECK(!is_none()) << "NArray is not initialized"; + inline NDArray Slice(index_t begin, index_t end) const { + NDArray ret = *this; + CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; size_t length = 1; if (shape_.ndim() == 1) { @@ -227,20 +227,20 @@ class NArray { return ret; } /*! - * \brief Get an reshaped NArray + * \brief Get an reshaped NDArray * \param shape new shape - * \return NArray in new shape + * \return NDArray in new shape */ - inline NArray Reshape(const TShape &shape) const { + inline NDArray Reshape(const TShape &shape) const { CHECK_GE(shape_.Size(), shape.Size()) - << "NArray.Reshape: target shape size is different from current shape"; - NArray ret = *this; + << "NDArray.Reshape: target shape size is different from current shape"; + NDArray ret = *this; ret.shape_ = shape; return ret; } private: - /*! \brief the real data chunk that backs NArray */ + /*! \brief the real data chunk that backs NDArray */ struct Chunk { /*! \brief storage handlefrom storage engine */ Storage::Handle shandle; @@ -294,120 +294,120 @@ class NArray { } } }; - /*! \brief internal data of NArray */ + /*! \brief internal data of NDArray */ std::shared_ptr ptr_; - /*! \brief shape of current NArray */ + /*! \brief shape of current NDArray */ TShape shape_; /*! 
\brief offset in chunk */ size_t offset_; // add friend to helper functions - friend void CopyFromTo(const NArray &from, NArray *to); + friend void CopyFromTo(const NDArray &from, NDArray *to); template - friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out); + friend void BinaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out); template - friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out); + friend void UnaryOp(const NDArray &lhs, const NDArray &rhs, NDArray *out); template - friend void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out); - friend void SetValueOp(const real_t &rhs, NArray *out); + friend void ScalarOp(const NDArray &lhs, const real_t &rhs, NDArray *out); + friend void SetValueOp(const real_t &rhs, NDArray *out); }; /*! - * \brief issue an copy operation from one NArray to another - * the two narray can sit on different devices + * \brief issue an copy operation from one NDArray to another + * the two ndarray can sit on different devices * this operation will be scheduled by the engine * * NOTE: this function name explicitly marks the order of from and to * due to different possible convention carried by copy function - * \param from the narray we want to copy data from - * \param to the target narray + * \param from the ndarray we want to copy data from + * \param to the target ndarray */ -void CopyFromTo(const NArray &from, NArray *to); +void CopyFromTo(const NDArray &from, NDArray *to); /*! * \brief elementwise add * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator+(const NArray &lhs, const NArray &rhs); +NDArray operator+(const NDArray &lhs, const NDArray &rhs); /*! * \brief elementwise add * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator+(const NArray &lhs, const real_t &rhs); +NDArray operator+(const NDArray &lhs, const real_t &rhs); /*! * \brief elementwise substraction * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator-(const NArray &lhs, const NArray &rhs); +NDArray operator-(const NDArray &lhs, const NDArray &rhs); /*! * \brief elementwise substraction * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator-(const NArray &lhs, const real_t &rhs); +NDArray operator-(const NDArray &lhs, const real_t &rhs); /*! * \brief elementwise multiplication * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator*(const NArray &lhs, const NArray &rhs);\ +NDArray operator*(const NDArray &lhs, const NDArray &rhs);\ /*! * \brief elementwise multiplication * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator*(const NArray &lhs, const real_t &rhs); +NDArray operator*(const NDArray &lhs, const real_t &rhs); /*! * \brief elementwise division * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator/(const NArray &lhs, const NArray &rhs); +NDArray operator/(const NDArray &lhs, const NDArray &rhs); /*! 
* \brief elementwise division * \param lhs left operand * \param rhs right operand - * \return a new result narray + * \return a new result ndarray */ -NArray operator/(const NArray &lhs, const real_t &rhs); +NDArray operator/(const NDArray &lhs, const real_t &rhs); //-------------------------------------------------------------- -// The following part are API Registration of NArray functions. +// The following part are API Registration of NDArray functions. //-------------------------------------------------------------- -/*! \brief definition of NArray function */ -typedef std::function NArrayAPIFunction; + NDArray **mutate_vars)> NDArrayAPIFunction; /*! \brief mask information on how functions can be exposed */ -enum NArrayFunctionTypeMask { +enum NDArrayFunctionTypeMask { /*! \brief all the use_vars should go before scalar */ - kNArrayArgBeforeScalar = 1, + kNDArrayArgBeforeScalar = 1, /*! \brief all the scalar should go before use_vars */ - kScalarArgBeforeNArray = 1 << 1, + kScalarArgBeforeNDArray = 1 << 1, /*! * \brief whether this function allows the handles in the target to - * be empty NArray that are not yet initialized, and will initialize + * be empty NDArray that are not yet initialized, and will initialize * them when the function is invoked. * * most function should support this, except copy between different - * devices, which requires the NArray to be pre-initialized with context + * devices, which requires the NDArray to be pre-initialized with context */ kAcceptEmptyMutateTarget = 1 << 2 }; -/*! \brief Registry entry for NArrayFunction */ -struct NArrayFunctionReg - : public dmlc::FunctionRegEntryBase { +/*! \brief Registry entry for NDArrayFunction */ +struct NDArrayFunctionReg + : public dmlc::FunctionRegEntryBase { /*! \brief number of variable used by this function */ unsigned num_use_vars; /*! \brief number of variable mutated by this function */ @@ -419,81 +419,81 @@ struct NArrayFunctionReg /*! * \brief constructor */ - NArrayFunctionReg() + NDArrayFunctionReg() : num_use_vars(0), num_mutate_vars(0), num_scalars(0), type_mask(0) {} /*! - * \brief set the function body to a NArray setvalue function + * \brief set the function body to a NDArray setvalue function * this will also auto set the parameters correctly * \param fsetvalue function body to set * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_function(void fsetvalue(const real_t &rhs, - NArray *out)) { - body = [fsetvalue] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { + inline NDArrayFunctionReg &set_function(void fsetvalue(const real_t &rhs, + NDArray *out)) { + body = [fsetvalue] (NDArray **used_vars, + real_t *s, NDArray **mutate_vars) { fsetvalue(s[0], mutate_vars[0]); }; num_mutate_vars = 1; num_scalars = 1; - // type_mask = kNArrayArgBeforeScalar; + // type_mask = kNDArrayArgBeforeScalar; this->add_argument("rhs", "real_t", "Right operand to the function."); return *this; } /*! 
- * \brief set the function body to a binary NArray function + * \brief set the function body to a binary NDArray function * this will also auto set the parameters correctly * \param fbinary function body to set * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_function(void fbinary(const NArray &lhs, - const NArray &rhs, - NArray *out)) { - body = [fbinary] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { + inline NDArrayFunctionReg &set_function(void fbinary(const NDArray &lhs, + const NDArray &rhs, + NDArray *out)) { + body = [fbinary] (NDArray **used_vars, + real_t *s, NDArray **mutate_vars) { fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); }; num_use_vars = 2; num_mutate_vars = 1; - type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; - this->add_argument("lhs", "NArray", "Left operand to the function."); - this->add_argument("rhs", "NArray", "Right operand to the function."); + type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + this->add_argument("lhs", "NDArray", "Left operand to the function."); + this->add_argument("rhs", "NDArray", "Right operand to the function."); return *this; } /*! - * \brief set the function body to a binary NArray function + * \brief set the function body to a binary NDArray function * this will also auto set the parameters correctly * \param fscalar function body to set * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_function(void fscalar(const NArray &lhs, + inline NDArrayFunctionReg &set_function(void fscalar(const NDArray &lhs, const real_t &rhs, - NArray *out)) { - body = [fscalar] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { + NDArray *out)) { + body = [fscalar] (NDArray **used_vars, + real_t *s, NDArray **mutate_vars) { fscalar(*used_vars[0], s[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1; - type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; - this->add_argument("lhs", "NArray", "Left operand to the function."); + type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + this->add_argument("lhs", "NDArray", "Left operand to the function."); this->add_argument("rhs", "real_t", "Right operand to the function."); return *this; } /*! - * \brief set the function body to a unary NArray function + * \brief set the function body to a unary NDArray function * this will also auto set the parameters correctly * \param funary function body to set * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_function(void funary(const NArray &src, - NArray *out)) { - body = [funary] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { + inline NDArrayFunctionReg &set_function(void funary(const NDArray &src, + NDArray *out)) { + body = [funary] (NDArray **used_vars, + real_t *s, NDArray **mutate_vars) { funary(*used_vars[0], mutate_vars[0]); }; num_use_vars = 1; num_mutate_vars = 1; - type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; - this->add_argument("src", "NArray", "Source input to the function."); + type_mask = kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + this->add_argument("src", "NDArray", "Source input to the function."); return *this; } /*! 
@@ -501,7 +501,7 @@ struct NArrayFunctionReg * \param n number of mutate variablesx * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_num_use_vars(unsigned n) { + inline NDArrayFunctionReg &set_num_use_vars(unsigned n) { num_use_vars = n; return *this; } /*! @@ -509,7 +509,7 @@ struct NArrayFunctionReg * \param n number of mutate variablesx * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_num_mutate_vars(unsigned n) { + inline NDArrayFunctionReg &set_num_mutate_vars(unsigned n) { num_mutate_vars = n; return *this; } /*! @@ -517,7 +517,7 @@ struct NArrayFunctionReg * \param n number of scalar arguments * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_num_scalars(unsigned n) { + inline NDArrayFunctionReg &set_num_scalars(unsigned n) { num_scalars = n; return *this; } /*! @@ -525,29 +525,29 @@ struct NArrayFunctionReg * \param tmask typemask * \return ref to the registered entry, used to set properties */ - inline NArrayFunctionReg &set_type_mask(int tmask) { + inline NDArrayFunctionReg &set_type_mask(int tmask) { type_mask = tmask; return *this; } -}; // NArrayFunctionReg +}; // NDArrayFunctionReg /*! - * \brief Macro to register NArray function + * \brief Macro to register NDArray function * * Example: the following code is example to register a plus * \code * - * REGISTER_NARRAY_FUN(Plus) + * REGISTER_NDARRAY_FUN(Plus) * .set_function(Plus); * * \endcode */ -#define MXNET_REGISTER_NARRAY_FUN(name) \ - DMLC_REGISTRY_REGISTER(::mxnet::NArrayFunctionReg, NArrayFunctionReg, name) +#define MXNET_REGISTER_NDARRAY_FUN(name) \ + DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) } // namespace mxnet namespace dmlc { /*!\brief traits */ -DMLC_DECLARE_TRAITS(has_saveload, mxnet::NArray, true); +DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true); } // namespace dmlc -#endif // MXNET_NARRAY_H_ +#endif // MXNET_NDARRAY_H_ diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 0842f53d347e..ae6f6af45df8 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -338,7 +338,7 @@ class OperatorProperty { } /*! * \brief Get Backward Input Dependency for generic types of data. - * Normally T can be pointer of Symbol::DataEntry, or NArray. + * Normally T can be pointer of Symbol::DataEntry, or NDArray. * This function will select the result list of T according to DeclareBackwardDependency. * * \param in_data the input data in forward pass. diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 28e82da32c06..e496ff42a673 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -16,7 +16,7 @@ #include #include #include "./base.h" -#include "./narray.h" +#include "./ndarray.h" #include "./operator.h" // check c++11 @@ -393,36 +393,36 @@ class Executor { /*! * \brief Perform a Backward operation of the Operator. * This must be called after Forward. - * After this operation, NArrays specified by grad_in_args_store will be updated accordingly. + * After this operation, NDArrays specified by grad_in_args_store will be updated accordingly. * User is allowed to pass in an empty Array if the head node is * loss function and head gradeitn is not needed. * * \param head_grads the gradient of head nodes to be backproped. */ - virtual void Backward(const std::vector &head_grads) = 0; + virtual void Backward(const std::vector &head_grads) = 0; /*! * \brief get array of heads in the executor. 
* \return array of heads in the executor. */ - virtual const std::vector &heads() const = 0; + virtual const std::vector &heads() const = 0; /*! * \brief Create an operator by bind symbol with context and arguments. * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. * * \param ctx the context of binding. * \param symbol the symbol that specifies the output of Forward pass. - * \param in_args the NArray that stores the input arguments to the symbol. - * \param arg_grad_store NArray that is used to store the gradient output of the input arguments. + * \param in_args the NDArray that stores the input arguments to the symbol. + * \param arg_grad_store NDArray that is used to store the gradient output of the input arguments. * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. - * \param aux_states NArray that is used as internal state in op + * \param aux_states NDArray that is used as internal state in op * \return a new executor. */ static Executor *Bind(Symbol symbol, Context ctx, - const std::vector &in_args, - const std::vector &arg_grad_store, + const std::vector &in_args, + const std::vector &arg_grad_store, const std::vector &grad_req_type, - const std::vector &aux_states); + const std::vector &aux_states); }; // class operator } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ diff --git a/make/config.mk b/make/config.mk index 3bc639ca1dba..73045cfc353d 100644 --- a/make/config.mk +++ b/make/config.mk @@ -24,17 +24,21 @@ USE_CUDA = 0 # if you have already add them to enviroment variable, leave it as NONE USE_CUDA_PATH = NONE -# whether use opencv during compilation -# you can disable it, however, you will not able to use -# imbin iterator -USE_OPENCV = 1 -USE_OPENCV_DECODER = 1 # whether use CUDNN R3 library USE_CUDNN = 0 + # add the path to CUDNN libary to link and compile flag # if you do not need that, or do not have that, leave it as NONE USE_CUDNN_PATH = NONE +# whether use opencv during compilation +# you can disable it, however, you will not able to use +# imbin iterator +USE_OPENCV = 1 + +# use openmp for parallelization +USE_OPENMP = 1 + # # choose the version of blas you want to use # can be: mkl, blas, atlas, openblas @@ -46,20 +50,6 @@ USE_BLAS = blas # USE_INTEL_PATH = NONE -# whether compile with parameter server -USE_DIST_PS = 0 -PS_PATH = NONE -PS_THIRD_PATH = NONE - -# whether compile with rabit -USE_RABIT_PS = 0 -RABIT_PATH = rabit - -# Whether to use threaded engine instead of naive one -# USE_THREADED_ENGINE =1 - -# use openmp iterator -USE_OPENMP_ITER = 1 # the additional link flags you want to add ADD_LDFLAGS = diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 7c4246f25285..c591fc29510b 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -10,9 +10,11 @@ from .context import Context, current_context, cpu, gpu from .base import MXNetError -from . import narray +from . import ndarray from . import symbol from . import kvstore from . import io +# use mx.nd as short for mx.ndarray +from . 
import ndarray as nd __version__ = "0.1.0" diff --git a/python/mxnet/base.py b/python/mxnet/base.py index df91998e9a45..969456aa2a3d 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -72,7 +72,7 @@ def _load_lib(): # type definitions mx_uint = ctypes.c_uint mx_float = ctypes.c_float -NArrayHandle = ctypes.c_void_p +NDArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p SymbolCreatorHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p @@ -164,7 +164,7 @@ def ctypes2numpy_shared(cptr, shape): pointer to the memory region shape : tuple - shape of target narray + shape of target ndarray Returns ------- diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index a3ba09ca1a76..189461074ffc 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -5,9 +5,9 @@ import ctypes from .base import _LIB -from .base import c_array, mx_uint, NArrayHandle, ExecutorHandle +from .base import c_array, mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call -from .narray import NArray +from .ndarray import NDArray class Executor(object): """ Executor is the actual executing object of MXNet.""" @@ -22,8 +22,8 @@ def __init__(self, handle): if not isinstance(handle, ExecutorHandle): raise TypeError("Handle type error") self.handle = handle - self.arg_narrays = [] - self.grad_narrays = [] + self.arg_ndarrays = [] + self.grad_ndarrays = [] self.auxiliary_states = [] def list_arguments(self, with_grad=True): @@ -41,9 +41,9 @@ def list_arguments(self, with_grad=True): Note: args sequence is same to symbol.list_arguments() """ if with_grad: - return self.arg_narrays, self.grad_narrays + return self.arg_ndarrays, self.grad_ndarrays else: - return self.arg_narrays + return self.arg_ndarrays def list_auxiliary_states(self): """Return auxiliary states of executor @@ -67,32 +67,32 @@ def backward(self, head_grads=None): Parameters ---------- - head_grads : NArray or list of NArray, optional + head_grads : NDArray or list of NDArray, optional Gradient on the heads """ if head_grads is None: head_grads = [] - elif isinstance(head_grads, NArray): + elif isinstance(head_grads, NDArray): head_grads = [head_grads] for obj in head_grads: - if not isinstance(obj, NArray): - raise TypeError("inputs must be NArray") - narray = c_array(NArrayHandle, [item.handle for item in head_grads]) - check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), narray)) + if not isinstance(obj, NDArray): + raise TypeError("inputs must be NDArray") + ndarray = c_array(NDArrayHandle, [item.handle for item in head_grads]) + check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), ndarray)) def heads(self): - """list all heads' output narray + """list all heads' output ndarray Returns ------- - A list of narray binded to the heads of executor. + A list of ndarray binded to the heads of executor. """ # TODO: think of access, make heads read only. - # (consider support read only NArray(NArrayView)) + # (consider support read only NDArray(NDArrayView)) # Otherwise some of the internal might depends on out_data # if user set the content of the head, the backward behavior can be incorrect. 
out_size = mx_uint() - handles = ctypes.POINTER(NArrayHandle)() + handles = ctypes.POINTER(NDArrayHandle)() check_call(_LIB.MXExecutorHeads(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) - return [NArray(NArrayHandle(handles[i])) for i in range(out_size.value)] + return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 58dbe6e6f9a3..aff8e1c8cb00 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -1,15 +1,15 @@ # coding: utf-8 -"""NArray interface of mxnet""" +"""NDArray interface of mxnet""" from __future__ import absolute_import import ctypes import sys from .base import _LIB from .base import c_array, c_str, mx_uint, py_str -from .base import DataIterHandle, NArrayHandle +from .base import DataIterHandle, NDArrayHandle from .base import check_call -from .narray import NArray +from .ndarray import NDArray class DataIter(object): """DataIter object in mxnet. List all the needed functions here. """ @@ -71,17 +71,17 @@ def getdata(self): """get data from batch """ - hdl = NArrayHandle() + hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) - return NArray(hdl) + return NDArray(hdl) def getlabel(self): """get label from batch """ - hdl = NArrayHandle() + hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) - return NArray(hdl) + return NDArray(hdl) def _make_io_iterator(handle): """Create an io iterator by handle.""" diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 61fadb739815..1132ac89e62e 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -3,21 +3,21 @@ """ KVStore in mxnet """ from __future__ import absolute_import import ctypes -from .narray import NArray +from .ndarray import NDArray from .base import _LIB -from .base import check_call, c_array, NArrayHandle +from .base import check_call, c_array, NDArrayHandle def _ctype_key_value(keys, vals): """parse key-value args into ctype""" if isinstance(keys, int): - if isinstance(vals, NArray): + if isinstance(vals, NDArray): return (c_array(ctypes.c_int, [keys]), - c_array(NArrayHandle, [vals.handle])) + c_array(NDArrayHandle, [vals.handle])) else: for v in vals: - assert(isinstance(v, NArray)) + assert(isinstance(v, NDArray)) return (c_array(ctypes.c_int, [keys] * len(vals)), - c_array(NArrayHandle, [v.handle for v in vals])) + c_array(NDArrayHandle, [v.handle for v in vals])) else: assert(len(keys) == len(vals)) for k in keys: @@ -28,7 +28,7 @@ def _ctype_key_value(keys, vals): c_key_i, c_val_i = _ctype_key_value(keys[i], vals[i]) c_keys += c_key_i c_vals += c_val_i - return (c_array(ctypes.c_int, c_keys), c_array(NArrayHandle, c_vals)) + return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) def start(): """start kvstore""" @@ -46,7 +46,7 @@ def init(key, value): ---------- keys: int or list of int A single key or a list of keys - values: NArray or list of NArray + values: NDArray or list of NDArray A single value of a list of values """ ckeys, cvals = _ctype_key_value(key, value) @@ -59,7 +59,7 @@ def push(key, value): ---------- key : int or list of int A single key or a list of key - value: list of NArray or list of list of NArray + value: list of NDArray or list of list of NDArray A single value of a list of value """ ckeys, cvals = _ctype_key_value(key, value) @@ -72,7 +72,7 @@ def pull(key, out=None): ---------- key: int or list of int A single key or a list of key - out: NArray or list of NArray + out: 
NDArray or list of NDArray A single value of a list of value """ assert(out is not None) @@ -85,8 +85,8 @@ def _updater_wrapper(updater): """ a wrapper for the user-defined handle """ def updater_handle(key, lhs_handle, rhs_handle): """ ctypes function """ - lhs = NArray(NArrayHandle(lhs_handle)) - rhs = NArray(NArrayHandle(rhs_handle)) + lhs = NDArray(NDArrayHandle(lhs_handle)) + rhs = NDArray(NDArrayHandle(rhs_handle)) updater(key, lhs, rhs) return updater_handle @@ -106,7 +106,7 @@ def updater(recv, local): updater: functon """ _updater_proto = ctypes.CFUNCTYPE( - None, ctypes.c_int, NArrayHandle, NArrayHandle) + None, ctypes.c_int, NDArrayHandle, NDArrayHandle) global _updater_func _updater_func = _updater_proto(_updater_wrapper(updater)) check_call(_LIB.MXKVStoreSetUpdater(_updater_func)) diff --git a/python/mxnet/narray.py b/python/mxnet/ndarray.py similarity index 64% rename from python/mxnet/narray.py rename to python/mxnet/ndarray.py index 208fd8e17d7a..4bd28f814d5d 100644 --- a/python/mxnet/narray.py +++ b/python/mxnet/ndarray.py @@ -1,5 +1,5 @@ # coding: utf-8 -"""NArray interface of mxnet""" +"""NDArray interface of mxnet""" from __future__ import absolute_import import ctypes @@ -8,7 +8,7 @@ import numpy as np from .base import _LIB, string_types, numeric_types from .base import c_array, py_str, c_str -from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle +from .base import mx_uint, mx_float, NDArrayHandle, FunctionHandle from .base import ctypes2buffer from .base import check_call from .context import Context @@ -20,10 +20,10 @@ def _new_empty_handle(): Returns ------- - a new empty narray handle + a new empty ndarray handle """ - hdl = NArrayHandle() - check_call(_LIB.MXNArrayCreateNone(ctypes.byref(hdl))) + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayCreateNone(ctypes.byref(hdl))) return hdl def _new_alloc_handle(shape, ctx, delay_alloc): @@ -33,10 +33,10 @@ def _new_alloc_handle(shape, ctx, delay_alloc): Returns ------- - a new empty narray handle + a new empty ndarray handle """ - hdl = NArrayHandle() - check_call(_LIB.MXNArrayCreate( + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayCreate( c_array(mx_uint, shape), len(shape), ctx.device_mask, @@ -45,39 +45,39 @@ def _new_alloc_handle(shape, ctx, delay_alloc): ctypes.byref(hdl))) return hdl -class NArray(object): - """NArray object in mxnet. +class NDArray(object): + """NDArray object in mxnet. - NArray is basic ndarray/Tensor like data structure in mxnet. + NDArray is basic ndarray/Tensor like data structure in mxnet. 
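    Example (editor's illustrative sketch, not part of this patch; assumes the
    mx.nd alias introduced in this change)::

        import mxnet as mx
        a = mx.nd.zeros((2, 3))       # allocated on the default cpu context
        a[:] = 1.0                    # in-place fill through __setitem__
        b = a.copyto(mx.cpu())        # copy into a fresh NDArray on a context
        print(b.asnumpy())            # synchronizing copy back to numpy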
""" # pylint: disable= no-member def __init__(self, handle): - """initialize a new NArray + """initialize a new NDArray Parameters ---------- - handle : NArrayHandle - NArray handle of C API + handle : NDArrayHandle + NDArray handle of C API """ - assert isinstance(handle, NArrayHandle) + assert isinstance(handle, NDArrayHandle) self.handle = handle def __del__(self): - check_call(_LIB.MXNArrayFree(self.handle)) + check_call(_LIB.MXNDArrayFree(self.handle)) def __add__(self, other): - if isinstance(other, NArray): - return NArray._plus(self, other) + if isinstance(other, NDArray): + return NDArray._plus(self, other) elif isinstance(other, numeric_types): - return NArray._plus_scalar(self, float(other)) + return NDArray._plus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __iadd__(self, other): - if isinstance(other, NArray): - return NArray._plus(self, other, out=self) + if isinstance(other, NDArray): + return NDArray._plus(self, other, out=self) elif isinstance(other, numeric_types): - return NArray._plus_scalar(self, float(other), out=self) + return NDArray._plus_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -85,43 +85,43 @@ def __radd__(self, other): return self.__add__(other) def __sub__(self, other): - if isinstance(other, NArray): - return NArray._minus(self, other) + if isinstance(other, NDArray): + return NDArray._minus(self, other) elif isinstance(other, numeric_types): - return NArray._minus_scalar(self, float(other)) + return NDArray._minus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __isub__(self, other): - if isinstance(other, NArray): - return NArray._minus(self, other, out=self) + if isinstance(other, NDArray): + return NDArray._minus(self, other, out=self) elif isinstance(other, numeric_types): - return NArray._minus_scalar(self, float(other), out=self) + return NDArray._minus_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) def __rsub__(self, other): if isinstance(other, numeric_types): - return NArray._rminus_scalar(self, float(other)) + return NDArray._rminus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __mul__(self, other): - if isinstance(other, NArray): - return NArray._mul(self, other) + if isinstance(other, NDArray): + return NDArray._mul(self, other) elif isinstance(other, numeric_types): - return NArray._mul_scalar(self, float(other)) + return NDArray._mul_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __neg__(self): - return NArray._mul_scalar(self, -1.0, out=self) + return NDArray._mul_scalar(self, -1.0, out=self) def __imul__(self, other): - if isinstance(other, NArray): - return NArray._mul(self, other, out=self) + if isinstance(other, NDArray): + return NDArray._mul(self, other, out=self) elif isinstance(other, numeric_types): - return NArray._mul_scalar(self, float(other), out=self) + return NDArray._mul_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -129,24 +129,24 @@ def __rmul__(self, other): return self.__mul__(other) def __div__(self, other): - if isinstance(other, NArray): - return NArray._div(self, other) + if isinstance(other, NDArray): + return NDArray._div(self, other) elif isinstance(other, numeric_types): - return NArray._div_scalar(self, float(other)) + return 
NDArray._div_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __rdiv__(self, other): if isinstance(other, numeric_types): - return NArray._rdiv_scalar(self, float(other)) + return NDArray._rdiv_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __idiv__(self, other): - if isinstance(other, NArray): - return NArray._div(self, other, out=self) + if isinstance(other, NDArray): + return NDArray._div(self, other, out=self) elif isinstance(other, numeric_types): - return NArray._div_scalar(self, float(other), out=self) + return NDArray._div_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -159,7 +159,7 @@ def __getstate__(self): if handle is not None: length = ctypes.c_ulong() cptr = ctypes.POINTER(ctypes.c_char)() - check_call(_LIB.MXNArraySaveRawBytes(self.handle, + check_call(_LIB.MXNDArraySaveRawBytes(self.handle, ctypes.byref(length), ctypes.byref(cptr))) this['handle'] = ctypes2buffer(cptr, length.value) @@ -169,31 +169,31 @@ def __setstate__(self, state): handle = state['handle'] if handle is not None: buf = handle - handle = NArrayHandle() + handle = NDArrayHandle() ptr = (ctypes.c_char * len(buf)).from_buffer(buf) length = ctypes.c_ulong(len(buf)) - check_call(_LIB.MXNArrayLoadFromRawBytes(ptr, length, ctypes.byref(handle))) + check_call(_LIB.MXNDArrayLoadFromRawBytes(ptr, length, ctypes.byref(handle))) state['handle'] = handle self.__dict__.update(state) def __setitem__(self, in_slice, value): - """Set narray value""" + """Set ndarray value""" if in_slice.step != None: - raise Exception("Set NArray should use empty index array[:] = target_array") - if isinstance(value, NArray): + raise Exception("Set NDArray should use empty index array[:] = target_array") + if isinstance(value, NDArray): if value.handle is not self.handle: value.copyto(self) elif isinstance(value, numeric_types): - NArray._set_value(float(value), out=self) + NDArray._set_value(float(value), out=self) elif isinstance(value, (np.ndarray, np.generic)): self._sync_copyfrom(value) else: raise TypeError('type %s not supported' % str(type(value))) def __getitem__(self, in_slice): - """Get narray""" + """Get ndarray""" if in_slice.step != None: - raise Exception("Set NArray should use empty index array[:] += value") + raise Exception("Set NDArray should use empty index array[:] += value") return self def _sync_copyfrom(self, source_array): @@ -213,57 +213,57 @@ def _sync_copyfrom(self, source_array): source_array = np.ascontiguousarray(source_array, dtype=np.float32) if source_array.shape != self.shape: - raise ValueError('array shape do not match the shape of NArray') + raise ValueError('array shape do not match the shape of NDArray') - check_call(_LIB.MXNArraySyncCopyFromCPU( + check_call(_LIB.MXNDArraySyncCopyFromCPU( self.handle, source_array.ctypes.data_as(ctypes.POINTER(mx_float)), source_array.size)) def wait_to_read(self): - """Block until all pending writes operations on current NArray are finished. + """Block until all pending writes operations on current NDArray are finished. This function will return when all the pending writes to the current - NArray finishes. There can still be pending read going on when the + NDArray finishes. There can still be pending read going on when the function returns. 
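        Example (editor's illustrative sketch, not part of this patch): work is
        pushed to the engine and the Python call returns right away, so an
        explicit wait is mainly useful when timing an operation or before
        touching the raw buffer::

            import time
            import mxnet as mx
            a = mx.nd.ones((1000, 1000))
            start = time.time()
            b = a * 2.0            # returns as soon as the op is queued
            b.wait_to_read()       # block until the result is actually computed
            print(time.time() - start)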
""" - check_call(_LIB.MXNArrayWaitToRead(self.handle)) + check_call(_LIB.MXNDArrayWaitToRead(self.handle)) def wait_to_write(self): - """Block until all pending read/write operations on current NArray are finished. + """Block until all pending read/write operations on current NDArray are finished. This function will return when all the pending writes to the current - NArray finishes. There can still be pending read going on when the + NDArray finishes. There can still be pending read going on when the function returns. """ - check_call(_LIB.MXNArrayWaitToWrite(self.handle)) + check_call(_LIB.MXNDArrayWaitToWrite(self.handle)) @property def shape(self): - """Get shape of current NArray. + """Get shape of current NDArray. Returns ------- - a tuple representing shape of current narray + a tuple representing shape of current ndarray """ ndim = mx_uint() pdata = ctypes.POINTER(mx_uint)() - check_call(_LIB.MXNArrayGetShape( + check_call(_LIB.MXNDArrayGetShape( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[:ndim.value]) @property def context(self): - """Get context of current NArray. + """Get context of current NDArray. Returns ------- context : mxnet.Context - The context of current NArray. + The context of current NDArray. """ dev_mask = ctypes.c_int() dev_id = ctypes.c_int() - check_call(_LIB.MXNArrayGetContext( + check_call(_LIB.MXNDArrayGetContext( self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id))) return Context(Context.devmask2type[dev_mask.value], dev_id.value) @@ -276,7 +276,7 @@ def asnumpy(self): A copy of array content. """ data = np.empty(self.shape, dtype=np.float32) - check_call(_LIB.MXNArraySyncCopyToCPU( + check_call(_LIB.MXNDArraySyncCopyToCPU( self.handle, data.ctypes.data, data.size)) @@ -285,89 +285,89 @@ def asnumpy(self): def copyto(self, other): """Copy the content of current array to other. - When other is NArray, the content is copied over. - When other is a Context, a new NArray in the context + When other is NDArray, the content is copied over. + When other is a Context, a new NDArray in the context will be created as target Parameters ---------- - other : NArray or Context - Target Narray or context we want to copy data to. + other : NDArray or Context + Target NDArray or context we want to copy data to. Returns ------- - dst : NArray - The copy target NArray + dst : NDArray + The copy target NDArray """ - if isinstance(other, NArray): + if isinstance(other, NDArray): if other.handle is self.handle: warnings.warn('copy an array to itself, is it intended?', RuntimeWarning) return - return NArray._copyto(self, out=other) + return NDArray._copyto(self, out=other) elif isinstance(other, Context): - hret = NArray(_new_alloc_handle(self.shape, other, True)) - return NArray._copyto(self, out=hret) + hret = NDArray(_new_alloc_handle(self.shape, other, True)) + return NDArray._copyto(self, out=hret) else: raise TypeError('copyto do not support type ' + type(other)) # pylint: enable= no-member def empty(shape, ctx=None): - """Create an empty uninitialized new NArray, with specified shape. + """Create an empty uninitialized new NDArray, with specified shape. Parameters ---------- shape : tuple - shape of the NArray. + shape of the NDArray. ctx : Context, optional - The context of the NArray, default to current default context. + The context of the NDArray, default to current default context. Returns ------- out: Array - The created NArray. + The created NDArray. 
""" if ctx is None: ctx = Context.default_ctx - return NArray(handle=_new_alloc_handle(shape, ctx, False)) + return NDArray(handle=_new_alloc_handle(shape, ctx, False)) def zeros(shape, ctx=None): - """Create a new NArray filled with 0, with specified shape. + """Create a new NDArray filled with 0, with specified shape. Parameters ---------- shape : tuple - shape of the NArray. + shape of the NDArray. ctx : Context, optional - The context of the NArray, default to current default context. + The context of the NDArray, default to current default context. Returns ------- out: Array - The created NArray. + The created NDArray. """ arr = empty(shape, ctx) arr[:] = 0.0 return arr def ones(shape, ctx=None): - """Create a new NArray filled with 1, with specified shape. + """Create a new NDArray filled with 1, with specified shape. Parameters ---------- shape : tuple - shape of the NArray. + shape of the NDArray. ctx : Context, optional - The context of the NArray, default to current default context. + The context of the NDArray, default to current default context. Returns ------- out: Array - The created NArray. + The created NDArray. """ arr = empty(shape, ctx) arr[:] = 1.0 @@ -375,20 +375,20 @@ def ones(shape, ctx=None): def array(source_array, ctx=None): - """Create a new NArray that copies content from source_array. + """Create a new NDArray that copies content from source_array. Parameters ---------- source_array : array_like - Source data to create NArray from. + Source data to create NDArray from. ctx : Context, optional - The context of the NArray, default to current default context. + The context of the NDArray, default to current default context. Returns ------- out: Array - The created NArray. + The created NDArray. """ if not isinstance(source_array, np.ndarray): @@ -402,7 +402,7 @@ def array(source_array, ctx=None): def load(fname): - """Load narray from binary file. + """Load ndarray from binary file. You can also use pickle to do the job if you only work on python. The advantage of load/save is the file is language agnostic. @@ -415,30 +415,30 @@ def load(fname): Returns ------- - out : list of NArray or dict of str to NArray - List of NArray or dict of str->NArray, depending on what was saved. + out : list of NDArray or dict of str to NDArray + List of NDArray or dict of str->NDArray, depending on what was saved. """ if not isinstance(fname, string_types): raise TypeError('fname need to be string') out_size = mx_uint() out_name_size = mx_uint() - handles = ctypes.POINTER(NArrayHandle)() + handles = ctypes.POINTER(NDArrayHandle)() names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNArrayListLoad(c_str(fname), + check_call(_LIB.MXNDArrayListLoad(c_str(fname), ctypes.byref(out_size), ctypes.byref(handles), ctypes.byref(out_name_size), ctypes.byref(names))) if out_name_size.value == 0: - return [NArray(NArrayHandle(handles[i])) for i in range(out_size.value)] + return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] else: assert out_name_size.value == out_size.value return dict( - (py_str(names[i]), NArray(NArrayHandle(handles[i]))) for i in range(out_size.value)) + (py_str(names[i]), NDArray(NDArrayHandle(handles[i]))) for i in range(out_size.value)) def save(fname, data): - """Save list of NArray or dict of str->NArray to binary file. + """Save list of NDArray or dict of str->NDArray to binary file. You can also use pickle to do the job if you only work on python. The advantage of load/save is the file is language agnostic. 
@@ -449,7 +449,7 @@ def save(fname, data): fname : str The name of the file - data : list of NArray or dict of str to NArray + data : list of NDArray or dict of str to NDArray The data to be saved. """ handles = [] @@ -457,30 +457,30 @@ def save(fname, data): keys = [] for key, val in data.items(): if not isinstance(key, string_types): - raise TypeError('save only accept dict str->NArray or list of NArray') - if not isinstance(val, NArray): - raise TypeError('save only accept dict str->NArray or list of NArray') + raise TypeError('save only accept dict str->NDArray or list of NDArray') + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') keys.append(c_str(key)) handles.append(val.handle) keys = c_array(ctypes.c_char_p, keys) else: for val in data: - if not isinstance(val, NArray): - raise TypeError('save only accept dict str->NArray or list of NArray') + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') handles.append(val.handle) keys = None - check_call(_LIB.MXNArrayListSave(c_str(fname), + check_call(_LIB.MXNDArrayListSave(c_str(fname), len(handles), - c_array(NArrayHandle, handles), + c_array(NDArrayHandle, handles), keys)) # pylint: disable=too-many-locals, invalid-name -def _make_narray_function(handle): - """Create a NArray function from the FunctionHandle.""" - NARRAY_ARG_BEFORE_SCALAR = 1 +def _make_ndarray_function(handle): + """Create a NDArray function from the FunctionHandle.""" + NDARRAY_ARG_BEFORE_SCALAR = 1 ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2 - # Get the property of NArray + # Get the property of NDArray n_mutate_vars = 0 n_used_vars = mx_uint() n_scalars = mx_uint() @@ -498,7 +498,7 @@ def _make_narray_function(handle): type_mask = type_mask.value accept_empty_mutate = (type_mask & ACCEPT_EMPTY_MUTATE_TARGET) != 0 # infer type of the function - if (type_mask & NARRAY_ARG_BEFORE_SCALAR) != 0: + if (type_mask & NDARRAY_ARG_BEFORE_SCALAR) != 0: use_vars_range = range(0, n_used_vars) scalar_range = range(n_used_vars, n_used_vars + n_scalars) else: @@ -532,97 +532,97 @@ def _make_narray_function(handle): 'Parameters\n' + '----------\n' + '%s\n' + - 'out : NArray, optional\n' + - ' The output NArray to hold the result.\n\n'+ + 'out : NDArray, optional\n' + + ' The output NDArray to hold the result.\n\n'+ 'Returns\n' + '-------\n' + - 'out : NArray\n'+ + 'out : NDArray\n'+ ' The output of binary function.') doc_str = doc_str % (py_str(desc.value), '\n'.join(param_str)) # Definition of internal functions. 
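    # ---- Illustrative usage sketch (editor's example, not part of this patch) --
    # The wrappers defined below are what the NDArray arithmetic operators end up
    # calling once registration has attached them as static methods, e.g.:
    #     import mxnet as mx
    #     a = mx.nd.ones((2, 2))
    #     b = mx.nd.ones((2, 2))
    #     c = a + b          # dispatches to NDArray._plus (binary wrapper)
    #     a += 1.0           # dispatches to NDArray._plus_scalar, writing in place
    # -----------------------------------------------------------------------------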
- def binary_narray_function(lhs, rhs, out=None): + def binary_ndarray_function(lhs, rhs, out=None): """Internal binary function """ if out: - if isinstance(out, NArray) == False: - raise TypeError('out must be NArray') + if isinstance(out, NDArray) == False: + raise TypeError('out must be NDArray') else: if not accept_empty_mutate: raise TypeError('argument out is required to call %s' % func_name) - out = NArray(_new_empty_handle()) + out = NDArray(_new_empty_handle()) check_call(_LIB.MXFuncInvoke(handle, - c_array(NArrayHandle, (lhs.handle, rhs.handle)), + c_array(NDArrayHandle, (lhs.handle, rhs.handle)), c_array(mx_float, ()), - c_array(NArrayHandle, (out.handle,)))) + c_array(NDArrayHandle, (out.handle,)))) return out - def unary_narray_function(src, out=None): - """internal NArray function""" + def unary_ndarray_function(src, out=None): + """internal NDArray function""" if out: - if isinstance(out, NArray) == False: - raise TypeError('out must be NArray') + if isinstance(out, NDArray) == False: + raise TypeError('out must be NDArray') else: if not accept_empty_mutate: raise TypeError('argument out is required to call %s' % func_name) - out = NArray(_new_empty_handle()) + out = NDArray(_new_empty_handle()) check_call(_LIB.MXFuncInvoke( \ handle, \ - c_array(NArrayHandle, (src.handle)), \ + c_array(NDArrayHandle, (src.handle)), \ c_array(mx_float, ()), \ - c_array(NArrayHandle, (out.handle,)))) + c_array(NDArrayHandle, (out.handle,)))) return out - def generic_narray_function(*args, **kwargs): + def generic_ndarray_function(*args, **kwargs): """Invoke this function by passing in parameters Parameters ---------- *args - Positional arguments of input scalars and NArray - out : NArray or tuple of NArray, optional - Output NArray, used to hold the output result. + Positional arguments of input scalars and NDArray + out : NDArray or tuple of NDArray, optional + Output NDArray, used to hold the output result. Returns ------- - out : NArray - The result NArray(tuple) of result of computation. + out : NDArray + The result NDArray(tuple) of result of computation. 
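        Example (editor's illustrative sketch, not part of this patch): since the
        unary branch below uses the same condition as the binary one
        (n_used_vars == 2), a one-input function such as _copyto falls through to
        this generic wrapper, which is how NDArray.copyto drives it::

            src = ones((2, 2))
            dst = empty((2, 2))
            NDArray._copyto(src, out=dst)   # one input NDArray, one mutated output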
""" if 'out' in kwargs: mutate_vars = kwargs['out'] - if isinstance(mutate_vars, NArray): + if isinstance(mutate_vars, NDArray): mutate_vars = (mutate_vars,) if len(mutate_vars) != n_mutate_vars: raise TypeError('expect %d out in %s', n_mutate_vars, func_name) else: if accept_empty_mutate: mutate_vars = tuple( - NArray(_new_empty_handle()) for i in range(n_mutate_vars)) + NDArray(_new_empty_handle()) for i in range(n_mutate_vars)) else: raise TypeError('argument out is required to call %s' % func_name) check_call(_LIB.MXFuncInvoke( \ handle, \ - c_array(NArrayHandle, [args[i].handle for i in use_vars_range]), \ + c_array(NDArrayHandle, [args[i].handle for i in use_vars_range]), \ c_array(mx_float, [args[i] for i in scalar_range]), \ - c_array(NArrayHandle, [v.handle for v in mutate_vars]))) + c_array(NDArrayHandle, [v.handle for v in mutate_vars]))) if n_mutate_vars == 1: return mutate_vars[0] else: return mutate_vars # End of function declaration if n_mutate_vars == 1 and n_used_vars == 2 and n_scalars == 0: - ret_function = binary_narray_function + ret_function = binary_ndarray_function elif n_mutate_vars == 1 and n_used_vars == 2 and n_scalars == 0: - ret_function = unary_narray_function + ret_function = unary_ndarray_function else: - ret_function = generic_narray_function + ret_function = generic_ndarray_function ret_function.__name__ = func_name ret_function.__doc__ = doc_str return ret_function # pylint: enable=too-many-locals, invalid-name -def _init_narray_module(): - """List and add all the narray functions to current module.""" +def _init_ndarray_module(): + """List and add all the ndarray functions to current module.""" plist = ctypes.POINTER(FunctionHandle)() size = ctypes.c_uint() check_call(_LIB.MXListFunctions(ctypes.byref(size), @@ -631,12 +631,12 @@ def _init_narray_module(): module_obj = sys.modules[__name__] for i in range(size.value): hdl = FunctionHandle(plist[i]) - function = _make_narray_function(hdl) - # if function name starts with underscore, register as static method of NArray + function = _make_ndarray_function(hdl) + # if function name starts with underscore, register as static method of NDArray if function.__name__.startswith('_'): - setattr(NArray, function.__name__, staticmethod(function)) + setattr(NDArray, function.__name__, staticmethod(function)) else: setattr(module_obj, function.__name__, function) -# Initialize the NArray module -_init_narray_module() +# Initialize the NDArray module +_init_ndarray_module() diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 4cf15d7c60f5..75a94d0e5ab4 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -7,10 +7,10 @@ import sys from .base import _LIB from .base import c_array, c_str, mx_uint, py_str, string_types -from .base import NArrayHandle, ExecutorHandle, SymbolHandle +from .base import NDArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context -from .narray import NArray, zeros +from .ndarray import NDArray, zeros from .executor import Executor @@ -281,19 +281,19 @@ def debug_str(self): return py_str(debug_str.value) @staticmethod - def _get_narray_handle(arg_key, args, arg_names, allow_missing): - """Helper function to get narray handles from various inputs. + def _get_ndarray_handle(arg_key, args, arg_names, allow_missing): + """Helper function to get ndarray handles from various inputs. Parameters ---------- arg_key : str The name of argument, used for error message. 
- args : list of NArray or dict of str->NArray + args : list of NDArray or dict of str->NDArray Input arguments to the symbols. - If type is list of NArray, the position is in the same order of arg_names. - If type is dict of str->NArray, then it maps the name of arguments - to the corresponding NArray, + If type is list of NDArray, the position is in the same order of arg_names. + If type is dict of str->NDArray, then it maps the name of arguments + to the corresponding NDArray, args_names : list of string List of argument names. @@ -304,8 +304,8 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): Returns ------- - handles : list of NArrayHandle - The positional list of NArrayHandles generated from input. + handles : list of NDArrayHandle + The positional list of NDArrayHandles generated from input. """ # setup args arg_handles = [] @@ -313,15 +313,15 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): if len(args) != len(arg_names): raise ValueError('Length of %s do not match number of arguments' % arg_key) for narr in args: - if not isinstance(narr, NArray): - raise TypeError('Only Accept list of NArrays or dict of str->NArray') + if not isinstance(narr, NDArray): + raise TypeError('Only Accept list of NDArrays or dict of str->NDArray') arg_handles.append(narr.handle) elif isinstance(args, dict): for name in arg_names: if name in arg_names: narr = args[name] - if not isinstance(narr, NArray): - raise TypeError('Only Accept list of NArrays or dict of str->NArray') + if not isinstance(narr, NDArray): + raise TypeError('Only Accept list of NDArrays or dict of str->NDArray') arg_handles.append(narr.handle) else: if allow_missing: @@ -329,11 +329,16 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): else: raise ValueError('Must specify all the arguments in %s' % arg_key) else: - raise TypeError('Only Accept list of NArrays or dict of str->NArray') - return c_array(NArrayHandle, arg_handles) + raise TypeError('Only Accept list of NDArrays or dict of str->NDArray') + return c_array(NDArrayHandle, arg_handles) def simple_bind(self, ctx, grad_req='write', **kwargs): - """Simply bind current symbol to get an executor + """Simply bind current symbol to get an executor. + + This function will ask user to pass in ndarray of position + they like to bind to, and it will automatically allocate the ndarray + for arguments and auxiliary states that user did not specify explicitly. + Parameters ---------- ctx : Context @@ -341,14 +346,11 @@ def simple_bind(self, ctx, grad_req='write', **kwargs): grad_req: string {'write', 'add', 'null'}, or list of str or dict of str->str, optional Specifies how we should update the gradient to the args_grad. - - 'write' means everytime gradient is write to specified args_grad NArray. - - 'add' means everytime gradient is add to the specified NArray. + - 'write' means everytime gradient is write to specified args_grad NDArray. + - 'add' means everytime gradient is add to the specified NDArray. - 'null' means no action is taken, the gradient may not be calculated. - kwargs : dict of str->NArray - Input arguments to the symbol. - - type is dict of str->NArray, then it maps the name of arguments - to the corresponding NArray, - - Not all the arguments must be provided. 
+ kwargs : dict of str->NDArray + Returns ------- executor : mxnet.Executor @@ -361,17 +363,17 @@ def simple_bind(self, ctx, grad_req='write', **kwargs): if arg_shapes == None: raise ValueError("Input node is not complete") # alloc space - arg_narrays = [] + arg_ndarrays = [] for name, shape in zip(self.list_arguments(), arg_shapes): if name in kwargs: - arg_narrays.append(kwargs[name]) + arg_ndarrays.append(kwargs[name]) else: - arg_narrays.append(zeros(shape, ctx)) + arg_ndarrays.append(zeros(shape, ctx)) # TODO(bing): specail treat input data grad # TODO(bing): not generate grad case - grad_narrays = [zeros(shape, ctx) for shape in arg_shapes] - aux_narrays = [zeros(shape, ctx) for shape in aux_shapes] - executor = self.bind(ctx, arg_narrays, grad_narrays, grad_req, aux_narrays) + grad_ndarrays = [zeros(shape, ctx) for shape in arg_shapes] + aux_ndarrays = [zeros(shape, ctx) for shape in aux_shapes] + executor = self.bind(ctx, arg_ndarrays, grad_ndarrays, grad_req, aux_ndarrays) return executor def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): @@ -382,35 +384,35 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): ctx : Context The device context the generated executor to run on. - args : list of NArray or dict of str->NArray + args : list of NDArray or dict of str->NDArray Input arguments to the symbol. - - If type is list of NArray, the position is in the same order of list_arguments. - - If type is dict of str->NArray, then it maps the name of arguments - to the corresponding NArray, + - If type is list of NDArray, the position is in the same order of list_arguments. + - If type is dict of str->NDArray, then it maps the name of arguments + to the corresponding NDArray, - In either case, all the arguments must be provided. - args_grad : list of NArray or dict of str->NArray, optional - When specified, args_grad provide NArrays to hold + args_grad : list of NDArray or dict of str->NDArray, optional + When specified, args_grad provide NDArrays to hold the result of gradient value in backward. - - If type is list of NArray, the position is in the same order of list_arguments. - - If type is dict of str->NArray, then it maps the name of arguments - to the corresponding NArray. - - When the type is dict of str->NArray, users only need to provide the dict + - If type is list of NDArray, the position is in the same order of list_arguments. + - If type is dict of str->NDArray, then it maps the name of arguments + to the corresponding NDArray. + - When the type is dict of str->NDArray, users only need to provide the dict for needed argument gradient. Only the specified argument gradient will be calculated. grad_req : {'write', 'add', 'null'}, or list of str or dict of str->str, optional Specifies how we should update the gradient to the args_grad. - - 'write' means everytime gradient is write to specified args_grad NArray. - - 'add' means everytime gradient is add to the specified NArray. + - 'write' means everytime gradient is write to specified args_grad NDArray. + - 'add' means everytime gradient is add to the specified NDArray. - 'null' means no action is taken, the gradient may not be calculated. - aux_states : list of NArray, or dict of str->NArray, optional + aux_states : list of NDArray, or dict of str->NDArray, optional Input auxiliary states to the symbol, only need to specify when list_auxiliary_states is not empty. 
- - If type is list of NArray, the position is in the same order of list_auxiliary_states - - If type is dict of str->NArray, then it maps the name of auxiliary_states - to the corresponding NArray, + - If type is list of NDArray, the position is in the same order of list_auxiliary_states + - If type is dict of str->NDArray, then it maps the name of auxiliary_states + to the corresponding NDArray, - In either case, all the auxiliary_states need to be provided. Returns @@ -432,17 +434,17 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): if not isinstance(ctx, Context): raise TypeError("Context type error") - args_handle = self._get_narray_handle('args', args, self.list_arguments(), False) + args_handle = self._get_ndarray_handle('args', args, self.list_arguments(), False) # setup args gradient if args_grad is None: - args_grad_handle = c_array(NArrayHandle, [None] * len(args)) + args_grad_handle = c_array(NDArrayHandle, [None] * len(args)) else: - args_grad_handle = self._get_narray_handle('args_grad', args_grad, + args_grad_handle = self._get_ndarray_handle('args_grad', args_grad, self.list_arguments(), True) if aux_states is None: aux_states = [] - aux_args_handle = self._get_narray_handle('aux_states', aux_states, + aux_args_handle = self._get_ndarray_handle('aux_states', aux_states, self.list_auxiliary_states(), False) # setup requirements @@ -474,8 +476,8 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): aux_args_handle, ctypes.byref(handle))) executor = Executor(handle) - executor.arg_narrays = args - executor.grad_narrays = args_grad + executor.arg_ndarrays = args + executor.grad_ndarrays = args_grad executor.auxiliary_states = aux_states return executor diff --git a/src/c_api.cc b/src/c_api.cc index 7ce3dac7cf61..8fedf170b3f2 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -181,101 +181,91 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, } // NOTE: return value is added in API_END -int MXNArrayCreateNone(NArrayHandle *out) { +int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); - *out = new NArray(); + *out = new NDArray(); API_END(); } -int MXNArrayCreateShareMem(mx_float *data, - mx_uint *shape, - mx_uint ndim, - NArrayHandle *out) { +int MXNDArrayCreate(const mx_uint *shape, + mx_uint ndim, + int dev_mask, + int dev_id, + int delay_alloc, + NDArrayHandle *out) { API_BEGIN(); - *out = new NArray(TBlob(data, TShape(shape, shape + ndim), - cpu::kDevMask), 0); - API_END(); -} - -int MXNArrayCreate(const mx_uint *shape, - mx_uint ndim, - int dev_mask, - int dev_id, - int delay_alloc, - NArrayHandle *out) { - API_BEGIN(); - *out = new NArray(TShape(shape, shape + ndim), + *out = new NDArray(TShape(shape, shape + ndim), Context(dev_mask, dev_id), delay_alloc != 0); API_END(); } -int MXNArrayLoadFromRawBytes(const void *buf, - mx_ulong size, - NArrayHandle *out) { - NArray *ptr = nullptr; +int MXNDArrayLoadFromRawBytes(const void *buf, + mx_ulong size, + NDArrayHandle *out) { + NDArray *ptr = nullptr; API_BEGIN(); dmlc::MemoryFixedSizeStream strm((void*)buf, size); // NOLINT(*) - ptr = new NArray(); + ptr = new NDArray(); if (!ptr->Load(&strm)) { - throw dmlc::Error("Invalid NArray serialization format"); + throw dmlc::Error("Invalid NDArray serialization format"); } *out = ptr; API_END_HANDLE_ERROR(delete ptr); } -int MXNArraySaveRawBytes(NArrayHandle handle, - mx_ulong *out_size, - const char **out_buf) { +int 
MXNDArraySaveRawBytes(NDArrayHandle handle, + mx_ulong *out_size, + const char **out_buf) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_str.resize(0); dmlc::MemoryStringStream strm(&ret->ret_str); - static_cast(handle)->Save(&strm); + static_cast(handle)->Save(&strm); *out_size = ret->ret_str.length(); *out_buf = ret->ret_str.c_str(); API_END(); } -int MXNArraySyncCopyFromCPU(NArrayHandle handle, - const mx_float *data, - size_t size) { +int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, + const mx_float *data, + size_t size) { API_BEGIN(); - static_cast(handle)->SyncCopyFromCPU(data, size); + static_cast(handle)->SyncCopyFromCPU(data, size); API_END(); } -int MXNArraySyncCopyToCPU(NArrayHandle handle, - mx_float *data, - size_t size) { +int MXNDArraySyncCopyToCPU(NDArrayHandle handle, + mx_float *data, + size_t size) { API_BEGIN(); - static_cast(handle)->SyncCopyToCPU(data, size); + static_cast(handle)->SyncCopyToCPU(data, size); API_END(); } -int MXNArrayWaitToRead(NArrayHandle handle) { +int MXNDArrayWaitToRead(NDArrayHandle handle) { API_BEGIN(); - static_cast(handle)->WaitToRead(); + static_cast(handle)->WaitToRead(); API_END(); } -int MXNArrayWaitToWrite(NArrayHandle handle) { +int MXNDArrayWaitToWrite(NDArrayHandle handle) { API_BEGIN(); - static_cast(handle)->WaitToWrite(); + static_cast(handle)->WaitToWrite(); API_END(); } -const uint64_t kMXAPINArrayListMagic = 0x112; +const uint64_t kMXAPINDArrayListMagic = 0x112; -int MXNArrayListSave(const char* fname, - mx_uint num_args, - NArrayHandle* args, - const char** keys) { +int MXNDArrayListSave(const char* fname, + mx_uint num_args, + NDArrayHandle* args, + const char** keys) { API_BEGIN(); - std::vector data(num_args); + std::vector data(num_args); std::vector names; for (mx_uint i = 0; i < num_args; ++i) { - data[i] = *static_cast(args[i]); + data[i] = *static_cast(args[i]); } if (keys != nullptr) { names.resize(num_args); @@ -284,7 +274,7 @@ int MXNArrayListSave(const char* fname, } } std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); - uint64_t header = kMXAPINArrayListMagic, reserved = 0; + uint64_t header = kMXAPINDArrayListMagic, reserved = 0; fo->Write(&header, sizeof(header)); fo->Write(&reserved, sizeof(reserved)); fo->Write(data); @@ -292,33 +282,33 @@ int MXNArrayListSave(const char* fname, API_END(); } -int MXNArrayListLoad(const char* fname, - mx_uint *out_size, - NArrayHandle** out_arr, - mx_uint *out_name_size, - const char*** out_names) { +int MXNDArrayListLoad(const char* fname, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); ret->ret_vec_str.clear(); API_BEGIN(); - std::vector data; + std::vector data; std::vector &names = ret->ret_vec_str; std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); uint64_t header, reserved; CHECK(fi->Read(&header)) - << "Invalid NArray file format"; + << "Invalid NDArray file format"; CHECK(fi->Read(&reserved)) - << "Invalid NArray file format"; - CHECK(header == kMXAPINArrayListMagic) - << "Invalid NArray file format"; + << "Invalid NDArray file format"; + CHECK(header == kMXAPINDArrayListMagic) + << "Invalid NDArray file format"; CHECK(fi->Read(&data)) - << "Invalid NArray file format"; + << "Invalid NDArray file format"; CHECK(fi->Read(&names)) - << "Invalid NArray file format"; + << "Invalid NDArray file format"; CHECK(names.size() == 0 || names.size() == data.size()) - << "Invalid NArray file format"; + << "Invalid NDArray file 
format"; ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { - NArray *ptr = new NArray(); + NDArray *ptr = new NDArray(); *ptr = data[i]; ret->ret_handles[i] = ptr; } @@ -333,23 +323,23 @@ int MXNArrayListLoad(const char* fname, API_END(); } -int MXNArrayWaitAll() { +int MXNDArrayWaitAll() { API_BEGIN(); Engine::Get()->WaitForAll(); API_END(); } -int MXNArrayFree(NArrayHandle handle) { +int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); - delete static_cast(handle); + delete static_cast(handle); API_END(); } -int MXNArrayGetShape(NArrayHandle handle, - mx_uint *out_dim, - const mx_uint **out_pdata) { +int MXNDArrayGetShape(NDArrayHandle handle, + mx_uint *out_dim, + const mx_uint **out_pdata) { API_BEGIN(); - NArray *arr = static_cast(handle); + NDArray *arr = static_cast(handle); if (!arr->is_none()) { const TShape &s = arr->shape(); *out_dim = s.ndim(); @@ -360,13 +350,13 @@ int MXNArrayGetShape(NArrayHandle handle, API_END(); } -int MXNArrayGetData(NArrayHandle handle, - mx_float **out_pdata) { +int MXNDArrayGetData(NDArrayHandle handle, + mx_float **out_pdata) { API_BEGIN(); - NArray *arr = static_cast(handle); + NDArray *arr = static_cast(handle); if (!arr->is_none()) { CHECK(arr->ctx().dev_mask == cpu::kDevMask) - << "MXNArrayGetData can only be called for NArray on CPU"; + << "MXNDArrayGetData can only be called for NDArray on CPU"; const TBlob &b = arr->data(); CHECK(b.CheckContiguous()); *out_pdata = b.FlatTo2D().dptr_; @@ -376,11 +366,11 @@ int MXNArrayGetData(NArrayHandle handle, API_END(); } -int MXNArrayGetContext(NArrayHandle handle, - int *out_dev_mask, - int *out_dev_id) { +int MXNDArrayGetContext(NDArrayHandle handle, + int *out_dev_mask, + int *out_dev_id) { API_BEGIN(); - NArray *arr = static_cast(handle); + NDArray *arr = static_cast(handle); if (!arr->is_none()) { const Context &ctx = arr->ctx(); *out_dev_mask = ctx.dev_mask; @@ -395,7 +385,7 @@ int MXNArrayGetContext(NArrayHandle handle, int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); + auto &vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); @@ -404,7 +394,7 @@ int MXListFunctions(mx_uint *out_size, int MXGetFunction(const char *name, FunctionHandle *out) { API_BEGIN(); - *out = dmlc::Registry::Find(name); + *out = dmlc::Registry::Find(name); API_END(); } @@ -415,7 +405,7 @@ int MXFuncGetInfo(FunctionHandle fun, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions) { - return MXAPIGetFunctionRegInfo(static_cast(fun), + return MXAPIGetFunctionRegInfo(static_cast(fun), name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } @@ -426,7 +416,7 @@ int MXFuncDescribe(FunctionHandle fun, mx_uint *num_mutate_vars, int *type_mask) { API_BEGIN(); - auto *f = static_cast(fun); + auto *f = static_cast(fun); *num_use_vars = f->num_use_vars; *num_scalars = f->num_scalars; *num_mutate_vars = f->num_mutate_vars; @@ -435,14 +425,14 @@ int MXFuncDescribe(FunctionHandle fun, } int MXFuncInvoke(FunctionHandle fun, - NArrayHandle *use_vars, + NDArrayHandle *use_vars, mx_float *scalar_args, - NArrayHandle *mutate_vars) { + NDArrayHandle *mutate_vars) { API_BEGIN(); - auto *f = static_cast(fun); - f->body((NArray**)(use_vars), // NOLINT(*) + auto *f = static_cast(fun); + f->body((NDArray**)(use_vars), // NOLINT(*) scalar_args, - (NArray**)(mutate_vars)); // NOLINT(*) + 
(NDArray**)(mutate_vars)); // NOLINT(*) API_END(); } @@ -703,28 +693,28 @@ int MXExecutorForward(ExecutorHandle handle, bool is_train) { int MXExecutorBackward(ExecutorHandle handle, mx_uint len, - NArrayHandle *head_grads) { + NDArrayHandle *head_grads) { API_BEGIN(); Executor *exec = static_cast(handle); - std::vector narrays; - NArray **args_ptr = reinterpret_cast(head_grads); + std::vector ndarrays; + NDArray **args_ptr = reinterpret_cast(head_grads); for (mx_uint i = 0; i < len; ++i) { - narrays.push_back(*args_ptr[i]); + ndarrays.push_back(*args_ptr[i]); } - exec->Backward(narrays); + exec->Backward(ndarrays); API_END(); } int MXExecutorHeads(ExecutorHandle handle, mx_uint *out_size, - NArrayHandle **out) { + NDArrayHandle **out) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); Executor *exec = static_cast(handle); - std::vector heads = exec->heads(); + std::vector heads = exec->heads(); ret->ret_handles.resize(heads.size()); for (size_t i = 0; i < heads.size(); ++i) { - NArray *ptr = new NArray(); + NDArray *ptr = new NDArray(); *ptr = heads[i]; ret->ret_handles[i] = ptr; } @@ -737,26 +727,26 @@ int MXExecutorBind(SymbolHandle symbol_handle, int dev_mask, int dev_id, mx_uint len, - NArrayHandle *in_args, - NArrayHandle *arg_grad_store, + NDArrayHandle *in_args, + NDArrayHandle *arg_grad_store, mx_uint *grad_req_type, mx_uint aux_states_len, - NArrayHandle *aux_states, + NDArrayHandle *aux_states, ExecutorHandle *out) { API_BEGIN(); Symbol *symb = static_cast(symbol_handle); Context ctx = Context(dev_mask, dev_id); - NArray **in_args_ptr = reinterpret_cast(in_args); - NArray **arg_grad_ptr = reinterpret_cast(arg_grad_store); - NArray **aux_states_ptr = reinterpret_cast(aux_states); - std::vector in_args_vec; - std::vector arg_grad_vec; + NDArray **in_args_ptr = reinterpret_cast(in_args); + NDArray **arg_grad_ptr = reinterpret_cast(arg_grad_store); + NDArray **aux_states_ptr = reinterpret_cast(aux_states); + std::vector in_args_vec; + std::vector arg_grad_vec; std::vector grad_req_vec; - std::vector aux_states_vec; + std::vector aux_states_vec; for (mx_uint i = 0; i < len; ++i) { in_args_vec.push_back(*(in_args_ptr[i])); if (arg_grad_ptr[i] == nullptr) { - arg_grad_vec.push_back(NArray()); + arg_grad_vec.push_back(NDArray()); grad_req_vec.push_back(kNullOp); } else { arg_grad_vec.push_back(*(arg_grad_ptr[i])); @@ -831,51 +821,51 @@ int MXDataIterNext(DataIterHandle handle, int *out) { API_END(); } -int MXDataIterGetLabel(DataIterHandle handle, NArrayHandle *out) { +int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); DataBatch db = static_cast* >(handle)->Value(); - *out = new NArray(db.data[1], 0); + *out = new NDArray(db.data[1], 0); API_END(); } -int MXDataIterGetData(DataIterHandle handle, NArrayHandle *out) { +int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); DataBatch db = static_cast* >(handle)->Value(); - *out = new NArray(db.data[0], 0); + *out = new NDArray(db.data[0], 0); API_END(); } -int MXKVStoreInit(int num, int* keys, NArrayHandle* vals) { +int MXKVStoreInit(int num, int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); - std::vector v_vals(num); + std::vector v_vals(num); for (int i = 0; i < num; ++i) { v_keys[i] = keys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } KVStore::Get()->Init(v_keys, v_vals); API_END(); } -int MXKVStorePush(int num, int* keys, NArrayHandle* vals) { +int MXKVStorePush(int num, int* keys, NDArrayHandle* vals) { 
API_BEGIN(); std::vector v_keys(num); - std::vector v_vals(num); + std::vector v_vals(num); for (int i = 0; i < num; ++i) { v_keys[i] = keys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } KVStore::Get()->Push(v_keys, v_vals); API_END(); } -int MXKVStorePull(int num, int* keys, NArrayHandle* vals) { +int MXKVStorePull(int num, int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); - std::vector v_vals(num); + std::vector v_vals(num); for (int i = 0; i < num; ++i) { v_keys[i] = keys[i]; - v_vals[i] = static_cast(vals[i]); + v_vals[i] = static_cast(vals[i]); } KVStore::Get()->Pull(v_keys, v_vals); API_END(); @@ -895,10 +885,10 @@ int MXKVStoreStop() { int MXKVStoreSetUpdater(MXKVStoreUpdater updater) { API_BEGIN(); - auto updt = [updater](int key, const NArray& recv, NArray* local) { - NArray* recv_copy = new NArray(); + auto updt = [updater](int key, const NDArray& recv, NDArray* local) { + NDArray* recv_copy = new NDArray(); *recv_copy = recv; - NArray* local_copy = new NArray(); + NDArray* local_copy = new NDArray(); *local_copy = *local; updater(key, recv_copy, local_copy); }; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index bea6cb019356..e781fc35da70 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -45,7 +45,7 @@ class KVStoreLocal : public KVStore { virtual int get_group_size() const { return 1; } virtual void Init(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { for (size_t i = 0; i < keys.size(); ++i) { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; @@ -54,9 +54,9 @@ class KVStoreLocal : public KVStore { } virtual void Push(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { std::vector uniq_keys; - std::vector > grouped_vals; + std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); CHECK(updater_) << "invalid updater"; @@ -69,16 +69,16 @@ class KVStoreLocal : public KVStore { } virtual void Pull(const std::vector& keys, - const std::vector& values) { + const std::vector& values) { std::vector uniq_keys; - std::vector > grouped_vals; + std::vector > grouped_vals; GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; - for (NArray* v : grouped_vals[i]) + for (NDArray* v : grouped_vals[i]) CopyFromTo(it->second, v); } } @@ -118,7 +118,7 @@ class KVStoreLocal : public KVStore { /** * \brief returns the aggregated push value */ - NArray MergePushValue(int key, const std::vector& val) { + NDArray MergePushValue(int key, const std::vector& val) { CHECK(val.size()); auto& buf = merge_buf_[key]; if (buf.merged.is_none()) { @@ -138,7 +138,7 @@ class KVStoreLocal : public KVStore { buf.gpu_buf.resize(id + 2); } if (buf.gpu_buf[id].is_none()) { - buf.gpu_buf[id] = NArray(v.shape(), pinned_ctx_); + buf.gpu_buf[id] = NDArray(v.shape(), pinned_ctx_); } CopyFromTo(v, &buf.gpu_buf[id]); buf.merged += buf.gpu_buf[id]; @@ -157,16 +157,16 @@ class KVStoreLocal : public KVStore { /// \brief temperal space for pushing value struct MergeBuf { /// \brief the cpu buffer for gpu data - std::vector gpu_buf; + std::vector gpu_buf; /// \brief merged data in cpu - NArray merged; + NDArray merged; }; /// \brief buffer for merging push value std::unordered_map merge_buf_; /// \brief local storage - 
std::unordered_map local_; + std::unordered_map local_; Context pinned_ctx_; diff --git a/src/narray/narray.cc b/src/ndarray/ndarray.cc similarity index 64% rename from src/narray/narray.cc rename to src/ndarray/ndarray.cc index 661e2004079a..8360937f084d 100644 --- a/src/narray/narray.cc +++ b/src/ndarray/ndarray.cc @@ -11,7 +11,7 @@ #include "./narray_function.h" namespace dmlc { -DMLC_REGISTRY_ENABLE(::mxnet::NArrayFunctionReg); +DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg); } // namespace dmlc namespace mxnet { @@ -19,19 +19,19 @@ namespace mxnet { * \brief run a binary operation * \param lhs left operand * \param rhs right operand - * \param out the output narray + * \param out the output ndarray * \param binary_op the real */ template -inline void BinaryOp(const NArray &lhs, - const NArray &rhs, - NArray *out) { +inline void BinaryOp(const NDArray &lhs, + const NDArray &rhs, + NDArray *out) { // no check if both of them are on cpu if (lhs.ctx().dev_mask != cpu::kDevMask || rhs.ctx().dev_mask != cpu::kDevMask) CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; // if out is none, allocate space if (out->is_none()) { - *out = NArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true); + *out = NDArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true); } else { // no check if both of them are on cpu if (lhs.ctx().dev_mask != cpu::kDevMask || @@ -42,7 +42,7 @@ inline void BinaryOp(const NArray &lhs, << "target shape mismatch"; } // important: callback must always capture by value - NArray ret = *out; + NDArray ret = *out; // get the const variables std::vector const_vars; if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); @@ -54,7 +54,7 @@ inline void BinaryOp(const NArray &lhs, Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } @@ -63,7 +63,7 @@ inline void BinaryOp(const NArray &lhs, Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, {ret.ptr_->var}); @@ -74,16 +74,16 @@ inline void BinaryOp(const NArray &lhs, } } -inline void SetValueOp(const real_t &rhs, NArray *out) { +inline void SetValueOp(const real_t &rhs, NDArray *out) { CHECK_NE(out->is_none(), true) << "Set value target must not be empty"; // important: callback must always capture by value - NArray ret = *out; + NDArray ret = *out; switch (ret.ctx().dev_mask) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(rhs, &tmp, ctx); + ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.ptr_->var}); break; } @@ -92,7 +92,7 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(rhs, &tmp, ctx); + ndarray::Eval(rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, ret.ctx(), {}, {ret.ptr_->var}); @@ -106,21 +106,21 @@ inline void SetValueOp(const real_t &rhs, NArray *out) { * \brief run a binary operation * \param lhs left operand * \param rhs right operand - * \param out the output narray + * \param 
out the output ndarray * \param binary_op the real */ template -inline void ScalarOp(const NArray &lhs, +inline void ScalarOp(const NDArray &lhs, const real_t &rhs, - NArray *out) { + NDArray *out) { if (out->is_none()) { - *out = NArray(lhs.shape(), lhs.ctx(), true); + *out = NDArray(lhs.shape(), lhs.ctx(), true); } else { CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; CHECK(out->shape() == lhs.shape()) << "target shape mismatch"; } // important: callback must always capture by value - NArray ret = *out; + NDArray ret = *out; // get the const variables std::vector const_vars; if (lhs.ptr_->var != ret.ptr_->var) const_vars.push_back(lhs.ptr_->var); @@ -131,7 +131,7 @@ inline void ScalarOp(const NArray &lhs, Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); + ndarray::Eval(lhs.data(), rhs, &tmp, ctx); }, lhs.ctx(), const_vars, {ret.ptr_->var}); break; } @@ -140,7 +140,7 @@ inline void ScalarOp(const NArray &lhs, Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); + ndarray::Eval(lhs.data(), rhs, &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); }, lhs.ctx(), const_vars, {ret.ptr_->var}); @@ -151,13 +151,13 @@ inline void ScalarOp(const NArray &lhs, } } -void CopyFromTo(const NArray &from, NArray *to) { +void CopyFromTo(const NDArray &from, NDArray *to) { CHECK(from.shape() == to->shape()) << "operands shape mismatch"; CHECK(from.shape().ndim() != 0) << "source operands have zero dimension shape"; // important: callback must always capture by value - NArray ret = *to; + NDArray ret = *to; int a = from.ctx().dev_mask; int b = to->ctx().dev_mask; @@ -168,7 +168,7 @@ void CopyFromTo(const NArray &from, NArray *to) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, + ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, from.ctx(), const_vars, {ret.ptr_->var}); } else { @@ -177,7 +177,7 @@ void CopyFromTo(const NArray &from, NArray *to) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, + ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); @@ -186,7 +186,7 @@ void CopyFromTo(const NArray &from, NArray *to) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, + ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); @@ -195,7 +195,7 @@ void CopyFromTo(const NArray &from, NArray *to) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Copy(from.data(), &tmp, + ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); @@ -210,94 +210,94 @@ void CopyFromTo(const NArray &from, NArray *to) { } template -inline NArray BinaryOpRet(const NArray &lhs, - const NArray &rhs) { - NArray ret; +inline NDArray BinaryOpRet(const NDArray &lhs, + const NDArray &rhs) { + NDArray ret; BinaryOp(lhs, rhs, &ret); return ret; } template -inline NArray ScalarOpRet(const NArray &lhs, - const real_t &rhs) { - NArray ret; +inline NDArray ScalarOpRet(const NDArray 
&lhs, + const real_t &rhs) { + NDArray ret; ScalarOp(lhs, rhs, &ret); return ret; } template -inline NArray &BinaryOpApply(NArray *dst, - const NArray &src) { +inline NDArray &BinaryOpApply(NDArray *dst, + const NDArray &src) { BinaryOp(*dst, src, dst); return *dst; } template -inline NArray &ScalarOpApply(NArray *dst, +inline NDArray &ScalarOpApply(NDArray *dst, const real_t &src) { ScalarOp(*dst, src, dst); return *dst; } // Binary -NArray operator+(const NArray &lhs, const NArray &rhs) { - return BinaryOpRet(lhs, rhs); +NDArray operator+(const NDArray &lhs, const NDArray &rhs) { + return BinaryOpRet(lhs, rhs); } -NArray operator-(const NArray &lhs, const NArray &rhs) { - return BinaryOpRet(lhs, rhs); +NDArray operator-(const NDArray &lhs, const NDArray &rhs) { + return BinaryOpRet(lhs, rhs); } -NArray operator*(const NArray &lhs, const NArray &rhs) { - return BinaryOpRet(lhs, rhs); +NDArray operator*(const NDArray &lhs, const NDArray &rhs) { + return BinaryOpRet(lhs, rhs); } -NArray operator/(const NArray &lhs, const NArray &rhs) { - return BinaryOpRet(lhs, rhs); +NDArray operator/(const NDArray &lhs, const NDArray &rhs) { + return BinaryOpRet(lhs, rhs); } // Scalar -NArray operator+(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); +NDArray operator+(const NDArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); } -NArray operator-(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); +NDArray operator-(const NDArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); } -NArray operator*(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); +NDArray operator*(const NDArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); } -NArray operator/(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); +NDArray operator/(const NDArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); } // Binary -NArray &NArray::operator=(real_t scalar) { +NDArray &NDArray::operator=(real_t scalar) { SetValueOp(scalar, this); return *this; } -NArray &NArray::operator+=(const NArray &src) { - return BinaryOpApply(this, src); +NDArray &NDArray::operator+=(const NDArray &src) { + return BinaryOpApply(this, src); } -NArray &NArray::operator-=(const NArray &src) { - return BinaryOpApply(this, src); +NDArray &NDArray::operator-=(const NDArray &src) { + return BinaryOpApply(this, src); } -NArray &NArray::operator*=(const NArray &src) { - return BinaryOpApply(this, src); +NDArray &NDArray::operator*=(const NDArray &src) { + return BinaryOpApply(this, src); } -NArray &NArray::operator/=(const NArray &src) { - return BinaryOpApply(this, src); +NDArray &NDArray::operator/=(const NDArray &src) { + return BinaryOpApply(this, src); } // Scalar -NArray &NArray::operator+=(const real_t &src) { - return ScalarOpApply(this, src); +NDArray &NDArray::operator+=(const real_t &src) { + return ScalarOpApply(this, src); } -NArray &NArray::operator-=(const real_t &src) { - return ScalarOpApply(this, src); +NDArray &NDArray::operator-=(const real_t &src) { + return ScalarOpApply(this, src); } -NArray &NArray::operator*=(const real_t &src) { - return ScalarOpApply(this, src); +NDArray &NDArray::operator*=(const real_t &src) { + return ScalarOpApply(this, src); } -NArray &NArray::operator/=(const real_t &src) { - return ScalarOpApply(this, src); +NDArray &NDArray::operator/=(const real_t &src) { + return ScalarOpApply(this, src); } -void NArray::Save(dmlc::Stream *strm) const { +void NDArray::Save(dmlc::Stream *strm) const { // 
save shape shape_.Save(strm); if (is_none()) return; @@ -305,7 +305,7 @@ void NArray::Save(dmlc::Stream *strm) const { Context ctx = this->ctx(); ctx.Save(strm); TBlob save_data; - NArray temp; + NDArray temp; if (ctx.dev_mask != cpu::kDevMask) { temp = this->Copy(Context(cpu::kDevMask, 0)); temp.WaitToRead(); @@ -317,7 +317,7 @@ void NArray::Save(dmlc::Stream *strm) const { // save type flag int32_t type_flag = save_data.type_flag_; CHECK(type_flag == mshadow::DataType::kFlag) - << "Only support float NArray so far"; + << "Only support float NDArray so far"; strm->Write(&type_flag, sizeof(type_flag)); CHECK(save_data.CheckContiguous()); // save data: need to change this after more type mask is supported @@ -325,12 +325,12 @@ void NArray::Save(dmlc::Stream *strm) const { strm->Write(save_data.dptr_, type_size * shape_.Size()); } -bool NArray::Load(dmlc::Stream *strm) { +bool NDArray::Load(dmlc::Stream *strm) { // load shape TShape shape; if (!shape.Load(strm)) return false; if (shape.ndim() == 0) { - *this = NArray(); return true; + *this = NDArray(); return true; } // load context Context ctx; @@ -339,9 +339,9 @@ bool NArray::Load(dmlc::Stream *strm) { int32_t type_flag; if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false; CHECK(type_flag == mshadow::DataType::kFlag) - << "Only support float NArray so far"; + << "Only support float NDArray so far"; // load data into CPUbu - NArray temp(shape, Context(cpu::kDevMask, ctx.dev_id)); + NDArray temp(shape, Context(cpu::kDevMask, ctx.dev_id)); TBlob load_data = temp.data(); size_t type_size = sizeof(real_t); size_t nread = type_size * shape.Size(); @@ -354,13 +354,13 @@ bool NArray::Load(dmlc::Stream *strm) { } } -NArray NArray::Copy(Context ctx) const { - NArray ret(shape(), ctx, true); +NDArray NDArray::Copy(Context ctx) const { + NDArray ret(shape(), ctx, true); CopyFromTo(*this, &ret); return ret; } -void NArray::SyncCopyFromCPU(const real_t *data, size_t size) const { +void NDArray::SyncCopyFromCPU(const real_t *data, size_t size) const { this->WaitToWrite(); TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) @@ -371,7 +371,7 @@ void NArray::SyncCopyFromCPU(const real_t *data, size_t size) const { RunContext run_ctx; if (ctx.dev_mask == cpu::kDevMask) { - narray::Copy(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx); + ndarray::Copy(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx); } else { #if MXNET_USE_CUDA // use empty stream to do sync copy @@ -379,14 +379,14 @@ void NArray::SyncCopyFromCPU(const real_t *data, size_t size) const { // Maybe move to engine part mshadow::Stream zero_stream; run_ctx.stream = &zero_stream; - narray::Copy(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx); + ndarray::Copy(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx); #else LOG(FATAL) << "GPU is not enabled"; #endif } } -void NArray::SyncCopyToCPU(real_t *data, size_t size) const { +void NDArray::SyncCopyToCPU(real_t *data, size_t size) const { this->WaitToRead(); TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) @@ -397,7 +397,7 @@ void NArray::SyncCopyToCPU(real_t *data, size_t size) const { RunContext run_ctx; if (ctx.dev_mask == cpu::kDevMask) { - narray::Copy(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx); + ndarray::Copy(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx); } else { #if MXNET_USE_CUDA // use empty stream to do sync copy @@ -405,7 +405,7 @@ void NArray::SyncCopyToCPU(real_t *data, size_t size) const { // Maybe move to engine part mshadow::Stream zero_stream; 
run_ctx.stream = &zero_stream; - narray::Copy(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx); + ndarray::Copy(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx); #else LOG(FATAL) << "GPU is not enabled"; #endif @@ -413,33 +413,33 @@ void NArray::SyncCopyToCPU(real_t *data, size_t size) const { } // register API function -// those with underscore will be registered at NArray -MXNET_REGISTER_NARRAY_FUN(_set_value).set_function(SetValueOp); +// those with underscore will be registered at NDArray +MXNET_REGISTER_NDARRAY_FUN(_set_value).set_function(SetValueOp); -MXNET_REGISTER_NARRAY_FUN(_plus).set_function(BinaryOp); -MXNET_REGISTER_NARRAY_FUN(_minus).set_function(BinaryOp); -MXNET_REGISTER_NARRAY_FUN(_mul).set_function(BinaryOp); -MXNET_REGISTER_NARRAY_FUN(_div).set_function(BinaryOp); +MXNET_REGISTER_NDARRAY_FUN(_plus).set_function(BinaryOp); +MXNET_REGISTER_NDARRAY_FUN(_minus).set_function(BinaryOp); +MXNET_REGISTER_NDARRAY_FUN(_mul).set_function(BinaryOp); +MXNET_REGISTER_NDARRAY_FUN(_div).set_function(BinaryOp); // register API function -// those with underscore will be registered at NArray +// those with underscore will be registered at NDArray // scalar -MXNET_REGISTER_NARRAY_FUN(_plus_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_minus_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_mul_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_div_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_plus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_minus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_mul_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_div_scalar).set_function(ScalarOp); // register API function -// those with underscore will be registered at NArray +// those with underscore will be registered at NDArray // scalar // reverse scalar -MXNET_REGISTER_NARRAY_FUN(_rminus_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_rdiv_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_rminus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NDARRAY_FUN(_rdiv_scalar).set_function(ScalarOp); // copy function is special // that we need to remove kAcceptEmptyMutateTarget from it -MXNET_REGISTER_NARRAY_FUN(_copyto) +MXNET_REGISTER_NDARRAY_FUN(_copyto) .set_function(CopyFromTo) -.set_type_mask(kNArrayArgBeforeScalar); +.set_type_mask(kNDArrayArgBeforeScalar); } // namespace mxnet diff --git a/src/narray/narray_function-inl.h b/src/ndarray/ndarray_function-inl.h similarity index 88% rename from src/narray/narray_function-inl.h rename to src/ndarray/ndarray_function-inl.h index 155a4e19c1b7..6494d64a148e 100644 --- a/src/narray/narray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2015 by Contributors - * \file narray_function-inl.h - * \brief + * \file ndarray_function-inl.h + * \brief The real implementation of NDArray functions. 
*/ -#ifndef MXNET_NARRAY_NARRAY_FUNCTION_INL_H_ -#define MXNET_NARRAY_NARRAY_FUNCTION_INL_H_ -#include "./narray_function.h" +#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ +#define MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ +#include "./ndarray_function.h" // this file will be included twice by CPU and GPU // macro to help specialize evaluation function #ifndef DECL_BINARY @@ -17,19 +17,19 @@ #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ template<> \ void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SETVALUE -#define DECL_SETVALUE(XPU) \ +#define DECL_SETVALUE(XPU) \ template<> \ - void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { \ - mshadow::Stream *s = static_cast*>(ctx.stream); \ - ret->FlatTo2D(s) = rhs; \ + void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { \ + mshadow::Stream *s = static_cast*>(ctx.stream); \ + ret->FlatTo2D(s) = rhs; \ } #endif @@ -41,7 +41,7 @@ #endif namespace mxnet { -namespace narray { +namespace ndarray { // true implementation template inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, @@ -84,7 +84,7 @@ DECL_SCALAR(DEVICE, Mul, EvalScalar_, false) DECL_SCALAR(DEVICE, Div, EvalScalar_, false) // DECL_SETVALUE(DEVICE) -} // namespace narray +} // namespace ndarray } // namespace mxnet -#endif // MXNET_NARRAY_NARRAY_FUNCTION_INL_H_ +#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_ diff --git a/src/narray/narray_function.cc b/src/ndarray/ndarray_function.cc similarity index 73% rename from src/narray/narray_function.cc rename to src/ndarray/ndarray_function.cc index d67bb91a23aa..a7881907655a 100644 --- a/src/narray/narray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -1,15 +1,15 @@ /*! 
* Copyright (c) 2015 by Contributors - * \file narray_function_cpu.cc + * \file ndarray_function_cpu.cc * \brief */ // this will be invoked by gcc and compile CPU version -#include "./narray_function.h" -#include "./narray_function-inl.h" +#include "./ndarray_function.h" +#include "./ndarray_function-inl.h" namespace mxnet { -namespace narray { +namespace ndarray { template<> void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, @@ -17,5 +17,5 @@ void Copy(const TBlob &from, TBlob *to, mshadow::Copy(to->FlatTo2D(), from.FlatTo2D()); } -} // namespace narray +} // namespace ndarray } // namespace mxnet diff --git a/src/narray/narray_function.cu b/src/ndarray/ndarray_function.cu similarity index 93% rename from src/narray/narray_function.cu rename to src/ndarray/ndarray_function.cu index f632b5dd65c3..3d17454c48ae 100644 --- a/src/narray/narray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -1,10 +1,10 @@ // this will be invoked by nvcc and compile GPU version #include -#include "./narray_function.h" -#include "./narray_function-inl.h" +#include "./ndarray_function.h" +#include "./ndarray_function-inl.h" namespace mxnet { -namespace narray { +namespace ndarray { template<> void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, @@ -44,5 +44,5 @@ void Copy(const TBlob &from, TBlob *to, s->stream_); } } -} // namespace narray +} // namespace ndarray } // namespace mxnet diff --git a/src/narray/narray_function.h b/src/ndarray/ndarray_function.h similarity index 80% rename from src/narray/narray_function.h rename to src/ndarray/ndarray_function.h index dc879c28c1e8..94b03ab05a6f 100644 --- a/src/narray/narray_function.h +++ b/src/ndarray/ndarray_function.h @@ -1,18 +1,18 @@ /*! * Copyright (c) 2015 by Contributors - * \file narray_op.h - * \brief the real execution functions of narray operations + * \file ndarray_op.h + * \brief the real execution functions of ndarray operations */ -#ifndef MXNET_NARRAY_NARRAY_FUNCTION_H_ -#define MXNET_NARRAY_NARRAY_FUNCTION_H_ +#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_H_ +#define MXNET_NDARRAY_NDARRAY_FUNCTION_H_ #include #include #include #include namespace mxnet { -/*! \brief namespace to support all possible NArray operator */ -namespace narray { +/*! 
\brief namespace to support all possible Ndarray operator */ +namespace ndarray { struct BinaryBase { inline static TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; @@ -48,6 +48,6 @@ void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); -} // namespace narray +} // namespace ndarray } // namespace mxnet -#endif // MXNET_NARRAY_NARRAY_FUNCTION_H_ +#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_H_ diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 2f0bd318cf67..7839f01164a8 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -251,10 +251,10 @@ void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { } } -void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, - const std::vector &arg_grad_store, +void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, + const std::vector &arg_grad_store, const std::vector &grad_req_type, - const std::vector &aux_states) { + const std::vector &aux_states) { CHECK_EQ(arg_grad_store.size(), grad_req_type.size()); CHECK_EQ(in_args.size(), graph_.arg_nodes.size()); // bind inputs @@ -321,18 +321,18 @@ void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, } } // bind aux args - size_t aux_narray_idx = 0; + size_t aux_ndarray_idx = 0; for (size_t i = 0; i < aux_shapes.size(); ++i) { op_nodes_[i].aux_states.resize(aux_shapes[i].size()); for (size_t j = 0; j < aux_shapes[i].size(); ++j) { DataEntryInfo &info = op_nodes_[i].aux_states[j]; info.shape = aux_shapes[i][j]; info.type = kBindByExternal; - CHECK_GT(aux_states.size(), aux_narray_idx) - << "Input auxiliary NArray is less than required"; - info.data = aux_states[aux_narray_idx++]; + CHECK_GT(aux_states.size(), aux_ndarray_idx) + << "Input auxiliary NDArray is less than required"; + info.data = aux_states[aux_ndarray_idx++]; CHECK_EQ(info.data.data().shape_, info.shape) - << "Incorrect NArray shape" + << "Incorrect NDArray shape" << " Input: " << info.data.data().shape_ << " Desired: " << info.shape; } @@ -420,7 +420,7 @@ void GraphExecutor::InitDataEntryMemory() { } // one pass complete, allocate real memory allocator.InitStorages(); - // get the real data NArray into the DataEntryInfo + // get the real data NDArray into the DataEntryInfo for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; if (!op_nodes_[nid].activated) continue; @@ -434,7 +434,7 @@ void GraphExecutor::InitDataEntryMemory() { for (StaticGraph::DataEntry e : graph_.heads) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; CHECK_EQ(info.type, kInternalAllocated); - heads_narray_.push_back(info.data); + heads_ndarray_.push_back(info.data); } } @@ -518,7 +518,7 @@ void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } -void GraphExecutor::Backward(const std::vector &head_grads) { +void GraphExecutor::Backward(const std::vector &head_grads) { if (head_grads.size() != 0) { // TODO(bing, min): consider pass a map for backward CHECK_EQ(head_grad_nodes_.size(), head_grads.size()); @@ -545,10 +545,10 @@ void GraphExecutor::Backward(const std::vector &head_grads) { Executor *Executor::Bind(Symbol symbol, Context ctx, - const std::vector &in_args, - const std::vector &arg_grad_store, + const std::vector &in_args, + const std::vector &arg_grad_store, const std::vector &grad_req_type, - const std::vector &aux_states) { + const std::vector &aux_states) { GraphExecutor *exec = new 
GraphExecutor(); exec->Init(symbol, ctx, in_args, arg_grad_store, grad_req_type, aux_states); return exec; diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 823f28b5398e..25b1ecd3a8bc 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -21,17 +21,17 @@ class GraphExecutor : public Executor { public: virtual ~GraphExecutor(); virtual void Forward(bool is_train); - virtual void Backward(const std::vector &head_grads); - virtual const std::vector &heads() const { - return heads_narray_; + virtual void Backward(const std::vector &head_grads); + virtual const std::vector &heads() const { + return heads_ndarray_; } // implement Executor::Bind, only call it once. inline void Init(Symbol symbol, Context ctx, - const std::vector &in_args, - const std::vector &arg_grad_store, + const std::vector &in_args, + const std::vector &arg_grad_store, const std::vector &grad_req_type, - const std::vector &aux_states) { + const std::vector &aux_states) { CHECK_EQ(grad_req_type.size(), arg_grad_store.size()); bool need_backward = false; for (auto req : grad_req_type) { @@ -52,9 +52,9 @@ class GraphExecutor : public Executor { class BackwardOpWrapper; // type of data entry enum DataEntryType { - // memory is binded by external NArray in Bind + // memory is binded by external NDArray in Bind kBindByExternal, - // to be binded by external NArray in Forward and Backward + // to be binded by external NDArray in Forward and Backward kTobeBindByExternal, // internal memory, allocated kInternalAllocated, @@ -64,7 +64,7 @@ class GraphExecutor : public Executor { // Additional information about each data entry struct DataEntryInfo { // the actual data for the entry - NArray data; + NDArray data; // write request to this entry OpReqType op_req; // the operatio node that will take @@ -161,11 +161,11 @@ class GraphExecutor : public Executor { // initialize the internal graph structure void InitGraph(Symbol symbol, Context ctx, bool need_backward); // initialize internal DataEntryInfo, reference counting - void InitDataEntryInfo(const std::vector &in_args, - const std::vector &arg_grad_store, + void InitDataEntryInfo(const std::vector &in_args, + const std::vector &arg_grad_store, const std::vector &grad_req_type, - const std::vector &aux_states); - // initialize internal data entries NArray + const std::vector &aux_states); + // initialize internal data entries NDArray void InitDataEntryMemory(); // initialize OpNode data structure void InitOpNodes(); @@ -186,8 +186,8 @@ class GraphExecutor : public Executor { std::vector arg_grads_; // operational nodes std::vector op_nodes_; - // head NArrays - std::vector heads_narray_; + // head NDArrays + std::vector heads_ndarray_; }; // class GraphExecutor } // namespace mxnet #endif // MXNET_SYMBOL_GRAPH_EXECUTOR_H_ diff --git a/src/symbol/graph_memory_allocator.h b/src/symbol/graph_memory_allocator.h index 9c995cd29993..cd6dc0648cb4 100644 --- a/src/symbol/graph_memory_allocator.h +++ b/src/symbol/graph_memory_allocator.h @@ -7,7 +7,7 @@ #define MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ #include -#include +#include #include #include @@ -23,8 +23,8 @@ namespace mxnet { * - Each call to Request will get a ResourceID that is used to * identify the memory block assigned to each DataEntryInfo. * (2) Allocating phase: GraphExecutor call InitMemory. - * - Then each DataEntry will call Get to get the real NArray. - * (3) All the memory will be freed up when reference to all the related NArray ends. 
+ * - Then each DataEntry will call Get to get the real NDArray. + * (3) All the memory will be freed up when reference to all the related NDArray ends. */ class GraphStorageAllocator { public: @@ -37,7 +37,7 @@ class GraphStorageAllocator { /*! * \brief Request a memory. * \param ctx the context of the graph - * \param shape shape of the NArray we want + * \param shape shape of the NDArray we want * \param node_id the node that is requesting the memory, used as hint. */ StorageID Request(Context ctx, TShape shape, uint32_t node_id); @@ -52,9 +52,9 @@ class GraphStorageAllocator { /*! * \brief Get the the memory allocated in planning phase. * \param id the storage id allocated in planning phase. - * \param shape the shape of the NArray requested. + * \param shape the shape of the NDArray requested. */ - NArray Get(StorageID id, TShape shape); + NDArray Get(StorageID id, TShape shape); protected: /*! \brief internal storage entry */ @@ -65,15 +65,15 @@ class GraphStorageAllocator { Context ctx; /*! \brief maximum size of the storage that is requested */ size_t max_size; - /*! \brief the actual NArray to hold the data */ - NArray data; + /*! \brief the actual NDArray to hold the data */ + NDArray data; /*! \brief constructor */ StorageEntry() : max_size(0) {} }; /*! * \brief Allocate a StorageID when Request cannot found existing ones. * \param ctx the context of the graph - * \param shape shape of the NArray we want + * \param shape shape of the NDArray we want */ StorageID Alloc(Context ctx, size_t size); @@ -132,11 +132,11 @@ void GraphStorageAllocator::InitStorages() { for (size_t i = 0; i < data_.size(); ++i) { StorageEntry *e = data_[i].get(); TShape shape = mshadow::Shape1(e->max_size); - e->data = NArray(shape, e->ctx); + e->data = NDArray(shape, e->ctx); } } -NArray GraphStorageAllocator::Get(StorageID id, TShape shape) { +NDArray GraphStorageAllocator::Get(StorageID id, TShape shape) { CHECK_NE(id, kBadStorageID); StorageEntry *e = data_[id].get(); return e->data.Slice(0, shape.Size()).Reshape(shape); diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py index 4affe6d8f200..93a8b2d806d5 100644 --- a/tests/python/train/test_conv.py +++ b/tests/python/train/test_conv.py @@ -31,9 +31,9 @@ def CalAcc(out, label): data_shape = (batch_size, 1, 28, 28) arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) -arg_narrays = [mx.narray.empty(shape) for shape in arg_shapes] -grad_narrays = [mx.narray.empty(shape) for shape in arg_shapes] -aux_narrays = [mx.narray.empty(shape) for shape in aux_shapes] +arg_narrays = [mx.nd.empty(shape) for shape in arg_shapes] +grad_narrays = [mx.nd.empty(shape) for shape in arg_shapes] +aux_narrays = [mx.nd.empty(shape) for shape in aux_shapes] inputs = dict(zip(args_list, arg_narrays)) np.random.seed(0) @@ -54,7 +54,7 @@ def CalAcc(out, label): # update out_narray = executor.heads()[0] -grad_narray = mx.narray.empty(out_narray.shape) +grad_narray = mx.nd.empty(out_narray.shape) epoch = 1 momentum = 0.9 diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index e2b6dcee8488..7198761796a4 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -4,7 +4,7 @@ import os, gzip import pickle as pickle from common import get_data - + def CalAcc(out, label): pred = np.argmax(out, axis=1) return np.sum(pred == label) * 1.0 / out.shape[0] @@ -22,8 +22,8 @@ def CalAcc(out, label): # infer shape data_shape = (batch_size, 784) arg_shapes, out_shapes, aux_shapes = 
softmax.infer_shape(data=data_shape) -arg_narrays = [mx.narray.empty(shape) for shape in arg_shapes] -grad_narrays = [mx.narray.empty(shape) for shape in arg_shapes] +arg_narrays = [mx.nd.empty(shape) for shape in arg_shapes] +grad_narrays = [mx.nd.empty(shape) for shape in arg_shapes] inputs = dict(zip(args_list, arg_narrays)) np.random.seed(0) # set random weight @@ -39,7 +39,7 @@ def CalAcc(out, label): # update out_narray = executor.heads()[0] -grad_narray = mx.narray.empty(out_narray.shape) +grad_narray = mx.nd.empty(out_narray.shape) epoch = 9 lr = 0.1 diff --git a/tests/python/unittest/test_bind.py b/tests/python/unittest/test_bind.py index 8802eb87d3c2..42f01b0556bb 100644 --- a/tests/python/unittest/test_bind.py +++ b/tests/python/unittest/test_bind.py @@ -16,10 +16,10 @@ def check_bind_with_uniform(uf, gf, dim): rhs = mx.symbol.Variable('rhs') ret = uf(lhs, rhs) assert ret.list_arguments() == ['lhs', 'rhs'] - lhs_arr = mx.narray.array(np.random.uniform(-10, 10, shape)) - rhs_arr = mx.narray.array(np.random.uniform(-10, 10, shape)) - lhs_grad = mx.narray.empty(shape) - rhs_grad = mx.narray.empty(shape) + lhs_arr = mx.nd.array(np.random.uniform(-10, 10, shape)) + rhs_arr = mx.nd.array(np.random.uniform(-10, 10, shape)) + lhs_grad = mx.nd.empty(shape) + rhs_grad = mx.nd.empty(shape) executor = ret.bind(mx.Context('cpu'), @@ -48,7 +48,7 @@ def check_bind_with_uniform(uf, gf, dim): assert reldiff(out1, out3) < 1e-6 assert reldiff(out1, out4) < 1e-6 # test gradient - out_grad = mx.narray.array(np.ones(shape)) + out_grad = mx.nd.array(np.ones(shape)) lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(), lhs_arr.asnumpy(), rhs_arr.asnumpy()) @@ -78,4 +78,4 @@ def test_bind(): if __name__ == "__main__": - test_bind() \ No newline at end of file + test_bind() diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 4152115e6629..72b671c74f5d 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -8,9 +8,9 @@ def init_kvstore(): """init kvstore """ mx.kvstore.start() # single - mx.kvstore.init(3, mx.narray.zeros(shape)) + mx.kvstore.init(3, mx.nd.zeros(shape)) # list - mx.kvstore.init(keys, [mx.narray.zeros(shape)] * len(keys)) + mx.kvstore.init(keys, [mx.nd.zeros(shape)] * len(keys)) def stop_kvstore(): """stop kvstore """ @@ -25,8 +25,8 @@ def test_single_kv_pair(): init_kvstore() - mx.kvstore.push(3, mx.narray.ones(shape)) - val = mx.narray.empty(shape) + mx.kvstore.push(3, mx.nd.ones(shape)) + val = mx.nd.empty(shape) mx.kvstore.pull(3, out = val) check_diff_to_scalar(val, 1) @@ -37,8 +37,8 @@ def test_list_kv_pair(): init_kvstore() - mx.kvstore.push(keys, [mx.narray.ones(shape)*4] * len(keys)) - val = [mx.narray.empty(shape)] * len(keys) + mx.kvstore.push(keys, [mx.nd.ones(shape)*4] * len(keys)) + val = [mx.nd.empty(shape)] * len(keys) mx.kvstore.pull(keys, out = val) for v in val: check_diff_to_scalar(v, 4) @@ -55,7 +55,7 @@ def test_aggregator(): devs = [mx.Context('cpu', i) for i in range(num_devs)] # single - vals = [mx.narray.ones(shape, d) for d in devs] + vals = [mx.nd.ones(shape, d) for d in devs] mx.kvstore.push(3, vals) mx.kvstore.pull(3, out = vals) @@ -64,7 +64,7 @@ def test_aggregator(): check_diff_to_scalar(v, num_devs) # list - vals = [[mx.narray.ones(shape, d)*2.0 for d in devs]] * len(keys) + vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) mx.kvstore.push(keys, vals) mx.kvstore.pull(keys, out = vals) @@ -89,7 +89,7 @@ def test_updater(dev = 'cpu'): devs = [mx.Context(dev, i) for i 
in range(num_devs)] # single - vals = [mx.narray.ones(shape, d) for d in devs] + vals = [mx.nd.ones(shape, d) for d in devs] mx.kvstore.push(3, vals) mx.kvstore.pull(3, out = vals) @@ -98,7 +98,7 @@ def test_updater(dev = 'cpu'): check_diff_to_scalar(v, num_devs) # list - vals = [[mx.narray.ones(shape, d) for d in devs]] * len(keys) + vals = [[mx.nd.ones(shape, d) for d in devs]] * len(keys) num_push = 4 for i in range(num_push): diff --git a/tests/python/unittest/test_narray.py b/tests/python/unittest/test_ndarray.py similarity index 87% rename from tests/python/unittest/test_narray.py rename to tests/python/unittest/test_ndarray.py index fd01abca9457..41fb5a632f06 100644 --- a/tests/python/unittest/test_narray.py +++ b/tests/python/unittest/test_ndarray.py @@ -19,7 +19,7 @@ def check_with_uniform(uf, arg_shapes, dim=None): numpy_arg = [] for s in arg_shapes: npy = np.random.uniform(-10, 10, s) - narr = mx.narray.array(npy) + narr = mx.nd.array(npy) narray_arg.append(narr) numpy_arg.append(npy) out1 = uf(*narray_arg) @@ -30,7 +30,7 @@ def check_with_uniform(uf, arg_shapes, dim=None): def random_narray(dim): shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) - data= mx.narray.array(np.random.uniform(-10, 10, shape)) + data= mx.nd.array(np.random.uniform(-10, 10, shape)) return data def test_narray_elementwise(): @@ -45,14 +45,14 @@ def test_narray_elementwise(): check_with_uniform(lambda x, y: x / y, 2, dim) def test_narray_copy(): - c = mx.narray.array(np.random.uniform(-10, 10, (10, 10))) + c = mx.nd.array(np.random.uniform(-10, 10, (10, 10))) d = c.copyto(mx.Context('cpu', 0)) assert np.sum(np.abs(c.asnumpy() != d.asnumpy())) == 0.0 def test_narray_scalar(): - c = mx.narray.empty((10,10)) - d = mx.narray.empty((10,10)) + c = mx.nd.empty((10,10)) + d = mx.nd.empty((10,10)) c[:] = 0.5 d[:] = 1.0 d -= c * 2 / 3 * 6.0 @@ -71,7 +71,7 @@ def test_narray_pickle(): for repeat in range(nrepeat): for dim in range(1, maxdim): a = random_narray(dim) - b = mx.narray.empty(a.shape) + b = mx.nd.empty(a.shape) a[:] = np.random.uniform(-10, 10, a.shape) b[:] = np.random.uniform(-10, 10, a.shape) a = a + b @@ -89,14 +89,14 @@ def test_narray_saveload(): data = [] for i in range(10): data.append(random_narray(np.random.randint(1, 5))) - mx.narray.save(fname, data) - data2 = mx.narray.load(fname) + mx.nd.save(fname, data) + data2 = mx.nd.load(fname) assert len(data) == len(data2) for x, y in zip(data, data2): assert np.sum(x.asnumpy() != y.asnumpy()) == 0 dmap = {'narray xx %s' % i : x for i, x in enumerate(data)} - mx.narray.save(fname, dmap) - dmap2 = mx.narray.load(fname) + mx.nd.save(fname, dmap) + dmap2 = mx.nd.load(fname) assert len(dmap2) == len(dmap) for k, x in dmap.items(): y = dmap2[k] diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 5b3540392dde..dcdf76cf6d9c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -20,8 +20,8 @@ def check_elementwise_sum_with_shape(shape, n): # forward inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] out = mx.symbol.ElementWiseSum(*inputs, name='esum') - arr = [mx.narray.empty(shape) for i in range(n)] - arr_grad = [mx.narray.empty(shape) for i in range(n)] + arr = [mx.nd.empty(shape) for i in range(n)] + arr_grad = [mx.nd.empty(shape) for i in range(n)] for i in range(n): arr[i][:] = np.random.uniform(-10, 10, shape) exec1 = out.bind(mx.Context('cpu'), @@ -32,7 +32,7 @@ def check_elementwise_sum_with_shape(shape, n): out1 = 
exec1.heads()[0].asnumpy() out = sum(a.asnumpy() for a in arr) assert reldiff(out, out1) < 1e-6 - out_grad = mx.narray.empty(shape) + out_grad = mx.nd.empty(shape) out_grad[:] = np.random.uniform(-10, 10, shape) # backward exec1.backward([out_grad]) @@ -58,14 +58,14 @@ def check_concat_with_shape(shapes): inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] out = mx.symbol.Concat(*inputs, name='conc') - arr = [mx.narray.empty(shape) for shape in shapes] + arr = [mx.nd.empty(shape) for shape in shapes] for i in range(n): arr[i][:] = shapes[i][1] arr_np = [np.copy(narray.asnumpy()) for narray in arr] - arr_grad = [mx.narray.empty(shape) for shape in shapes] + arr_grad = [mx.nd.empty(shape) for shape in shapes] args = out.list_arguments() arg_shapes, out_shapes, aux_shapes = out.infer_shape(**dict(zip(args, shapes))) - out_grad = mx.narray.empty(out_shapes[0]) + out_grad = mx.nd.empty(out_shapes[0]) exec1 = out.bind(mx.Context('cpu'), args=arr, args_grad=arr_grad) From da7c7c60c0396cc42b5b86dd07fd3a3f8dd7a3e0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 13 Sep 2015 10:48:35 -0700 Subject: [PATCH 10/13] update examples --- example/cifar10/cifar10.py | 11 ++++++----- example/mnist/mlp_gpu.py | 12 ++++++------ example/mnist/mlp_multi_gpu.py | 8 ++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 9b387b8d297a..b8bf1cb383f4 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -132,7 +132,7 @@ def RandomInit(narray): in_num = narray.shape[1] out_num = narray.shape[0] a = np.sqrt(3.0 / (in_num + out_num)) - tmp = mx.narray.array(np.random.uniform(-a, a, narray.shape)) + tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape)) narray[:] = tmp data = mx.symbol.Variable(name="data") @@ -161,15 +161,16 @@ def RandomInit(narray): batch_size = 128 data_shape = (batch_size, 3, 28, 28) -in_data = mx.narray.empty(data_shape, mx.gpu()) +in_data = mx.nd.empty(data_shape, mx.gpu()) executor = loss.simple_bind(mx.gpu(), data = in_data) + out_narray = executor.heads()[0] -pred = mx.narray.zeros(out_narray.shape, mx.cpu()) +pred = mx.nd.zeros(out_narray.shape, mx.cpu()) arg_narrays, grad_narrays = executor.list_arguments() inputs = dict(zip(loss.list_arguments(), arg_narrays)) -tmp_label = mx.narray.zeros(inputs["sm_label"].shape) -momentum_narrays = [mx.narray.zeros(item.shape, mx.gpu()) for item in grad_narrays] +tmp_label = mx.nd.zeros(inputs["sm_label"].shape) +momentum_narrays = [mx.nd.zeros(item.shape, mx.gpu()) for item in grad_narrays] block = list(zip(grad_narrays, arg_narrays, momentum_narrays)) diff --git a/example/mnist/mlp_gpu.py b/example/mnist/mlp_gpu.py index 010d09912523..ba4515e2c8b9 100644 --- a/example/mnist/mlp_gpu.py +++ b/example/mnist/mlp_gpu.py @@ -30,20 +30,20 @@ def CalAcc(out, label): arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) # create GPU NArray for data -arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] +arg_narrays = [mx.nd.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] +grad_narrays = [mx.nd.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] inputs = dict(zip(args_list, arg_narrays)) # create CPU NArray for result stat name2shape = dict(zip(args_list, arg_shapes)) -pred = mx.narray.zeros(out_shapes[0]) +pred = mx.nd.zeros(out_shapes[0]) # set random weight np.random.seed(0) for name, narray 
in inputs.items(): if "weight" in name: - tmp = mx.narray.array(np.random.uniform(-0.07, 0.07, name2shape[name])) + tmp = mx.nd.array(np.random.uniform(-0.07, 0.07, name2shape[name])) tmp.copyto(narray) # bind executer @@ -51,7 +51,7 @@ def CalAcc(out, label): executor = softmax.bind(mx.Context('gpu'), arg_narrays, grad_narrays) # create gradient NArray out_narray = executor.heads()[0] -grad_narray = mx.narray.zeros(out_narray.shape, ctx=mx.Context("gpu")) +grad_narray = mx.nd.zeros(out_narray.shape, ctx=mx.Context("gpu")) # update @@ -77,7 +77,7 @@ def Update(grad, weight): label="data/t10k-labels-idx1-ubyte", batch_size=batch_size, shuffle=True, flat=True, silent=False) -tmp_label = mx.narray.zeros(name2shape["sm_label"]) +tmp_label = mx.nd.zeros(name2shape["sm_label"]) def test_mlp(): acc_train = 0. diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index c618e84616aa..9ada84668727 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -43,18 +43,18 @@ def updater(key, grad, weight): np.random.seed(0) for idx in sync_indices: shape = param_shapes[idx] - val = mx.narray.zeros(shape) + val = mx.nd.zeros(shape) if "weight" in param_names[idx]: val[:] = np.random.uniform(-0.07, 0.07, shape) mx.kvstore.init(idx, val) # allocate device's memory -params = [[mx.narray.zeros(s, d) for s in param_shapes] for d in devs] -grads = [[mx.narray.zeros(s, d) for s in param_shapes] for d in devs] +params = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs] +grads = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs] # create executors for devices executors = [mlp.bind(devs[d], params[d], grads[d]) for d in range(num_devs)] -forward_out = [mx.narray.zeros(e.heads()[0].shape) for e in executors] +forward_out = [mx.nd.zeros(e.heads()[0].shape) for e in executors] # data reader get_data.GetMNIST_ubyte() From 99c797503eda45113e8207bc8ec46ab6e340bbf2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 13 Sep 2015 11:27:39 -0700 Subject: [PATCH 11/13] fix lint and compile --- python/mxnet/ndarray.py | 18 +++++++++--------- python/mxnet/symbol.py | 4 ++-- src/ndarray/ndarray.cc | 8 ++++---- src/ndarray/ndarray_function.cc | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 4bd28f814d5d..5b9323639298 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -160,8 +160,8 @@ def __getstate__(self): length = ctypes.c_ulong() cptr = ctypes.POINTER(ctypes.c_char)() check_call(_LIB.MXNDArraySaveRawBytes(self.handle, - ctypes.byref(length), - ctypes.byref(cptr))) + ctypes.byref(length), + ctypes.byref(cptr))) this['handle'] = ctypes2buffer(cptr, length.value) return this @@ -425,10 +425,10 @@ def load(fname): handles = ctypes.POINTER(NDArrayHandle)() names = ctypes.POINTER(ctypes.c_char_p)() check_call(_LIB.MXNDArrayListLoad(c_str(fname), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) if out_name_size.value == 0: return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] else: @@ -470,9 +470,9 @@ def save(fname, data): handles.append(val.handle) keys = None check_call(_LIB.MXNDArrayListSave(c_str(fname), - len(handles), - c_array(NDArrayHandle, handles), - keys)) + len(handles), + c_array(NDArrayHandle, handles), + keys)) # pylint: disable=too-many-locals, invalid-name diff --git 
a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 75a94d0e5ab4..3735ca861f2d 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -440,12 +440,12 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): args_grad_handle = c_array(NDArrayHandle, [None] * len(args)) else: args_grad_handle = self._get_ndarray_handle('args_grad', args_grad, - self.list_arguments(), True) + self.list_arguments(), True) if aux_states is None: aux_states = [] aux_args_handle = self._get_ndarray_handle('aux_states', aux_states, - self.list_auxiliary_states(), False) + self.list_auxiliary_states(), False) # setup requirements req_map = {'null' : 0, 'write' : 1, 'add' : 3} diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 8360937f084d..e9be7e445da6 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1,14 +1,14 @@ /*! * Copyright (c) 2015 by Contributors - * \file narray.cc - * \brief narry module of mxnet + * \file ndarray.cc + * \brief ndarry module of mxnet */ #include #include #include -#include +#include #include -#include "./narray_function.h" +#include "./ndarray_function.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg); diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index a7881907655a..e6dcdcde91b3 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2015 by Contributors * \file ndarray_function_cpu.cc - * \brief + * \brief CPU Implementation of ndarray function. */ // this will be invoked by gcc and compile CPU version From 4471fc84b43b3b0b54fac5891b628c0492c41a5a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 13 Sep 2015 11:44:22 -0700 Subject: [PATCH 12/13] rename returns/heads -> outputs --- example/cifar10/cifar10.py | 2 +- example/mnist/mlp_gpu.py | 8 ++++---- include/mxnet/c_api.h | 8 ++++---- include/mxnet/operator.h | 20 ++++++++++---------- include/mxnet/symbolic.h | 10 +++++----- python/mxnet/executor.py | 5 +++-- python/mxnet/symbol.py | 10 +++++----- src/c_api.cc | 12 ++++++------ src/operator/batch_norm-inl.h | 6 +++--- src/symbol/graph_executor.cc | 10 +++++----- src/symbol/graph_executor.h | 2 +- src/symbol/static_graph.cc | 6 +++--- src/symbol/symbol.cc | 18 +++++++++--------- tests/python/train/test_conv.py | 2 +- tests/python/train/test_mlp.py | 2 +- tests/python/unittest/test_bind.py | 6 +++--- tests/python/unittest/test_operator.py | 6 +++--- tests/python/unittest/test_symbol.py | 4 ++-- 18 files changed, 69 insertions(+), 68 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index b8bf1cb383f4..ce8aa2c8823e 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -164,7 +164,7 @@ def RandomInit(narray): in_data = mx.nd.empty(data_shape, mx.gpu()) executor = loss.simple_bind(mx.gpu(), data = in_data) -out_narray = executor.heads()[0] +out_narray = executor.outputs[0] pred = mx.nd.zeros(out_narray.shape, mx.cpu()) arg_narrays, grad_narrays = executor.list_arguments() diff --git a/example/mnist/mlp_gpu.py b/example/mnist/mlp_gpu.py index ba4515e2c8b9..a801476decfe 100644 --- a/example/mnist/mlp_gpu.py +++ b/example/mnist/mlp_gpu.py @@ -30,8 +30,8 @@ def CalAcc(out, label): arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) # create GPU NArray for data -arg_narrays = [mx.nd.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -grad_narrays = [mx.nd.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] 
+arg_narrays = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes] +grad_narrays = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes] inputs = dict(zip(args_list, arg_narrays)) # create CPU NArray for result stat @@ -50,8 +50,8 @@ def CalAcc(out, label): # TODO(bing): think of a better bind interface executor = softmax.bind(mx.Context('gpu'), arg_narrays, grad_narrays) # create gradient NArray -out_narray = executor.heads()[0] -grad_narray = mx.nd.zeros(out_narray.shape, ctx=mx.Context("gpu")) +out_narray = executor.outputs[0] +grad_narray = mx.nd.zeros(out_narray.shape, ctx=mx.gpu()) # update diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 494807a47fcc..043ad050d23f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -388,7 +388,7 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol, +MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); /*! @@ -502,9 +502,9 @@ MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, * \param out out put narray handles * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXExecutorHeads(ExecutorHandle handle, - mx_uint *out_size, - NDArrayHandle **out); +MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, + mx_uint *out_size, + NDArrayHandle **out); /*! * \brief Generate Executor from symbol diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index ae6f6af45df8..92b5f034a3c9 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -94,8 +94,8 @@ class Operator { * * \note * Convention: - * out_grad.size() == OperatorProperty.NumVisibleReturns() - * out_data.size() == OperatorProperty.NumReturns() + * out_grad.size() == OperatorProperty.NumVisibleOutputs() + * out_data.size() == OperatorProperty.NumOutputs() * out_data can contain additional invisible returns that remembers the * state carried from the Forward pass. For example mask in the dropout. * The gradients are passed from visible returns in this function. @@ -157,10 +157,10 @@ class OperatorProperty { return {"data"}; } /*! - * \brief Get name of return values of Operator - * \return name of return values. + * \brief Get name of output values of Operator + * \return name of output values. */ - virtual std::vector ListReturns() const { + virtual std::vector ListOutputs() const { return {"output"}; } /*! @@ -171,23 +171,23 @@ class OperatorProperty { return {}; } /*! \return number of real return values of the Operator */ - virtual int NumReturns() const { + virtual int NumOutputs() const { return 1; } /*! * \brief get number of visible return values during Symbol creation. - * If NumVisibleReturns() = k, and NumReturns() = n. + * If NumVisibleOutputs() = k, and NumOutputs() = n. * The first k returns will be presented in the resulting symbol. * * The rest of the returns can be used for auxiliary states for Backward. - * For example, Dropout will return [data, mask], with NumVisibleReturns() == 1. + * For example, Dropout will return [data, mask], with NumVisibleOutputs() == 1. * So when user call sym = Dropout(input), only data is presented in sym. * But all the returns will be presented in out_data parameter of Backward if requested. 
* * \return number of default return values */ - virtual int NumVisibleReturns() const { - return NumReturns(); + virtual int NumVisibleOutputs() const { + return NumOutputs(); } /*! * \brief infer the shapes of outputs and unknown input arguments diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index e496ff42a673..9f1f21d69d59 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -208,7 +208,7 @@ class Symbol { */ std::vector ListArguments() const; /*! \return get the descriptions of outputs for this symbol */ - std::vector ListReturns() const; + std::vector ListOutputs() const; /*! \return get the descriptions of auxiliary data for this symbol */ std::vector ListAuxiliaryStates() const; /*! @@ -303,7 +303,7 @@ class Symbol { * \brief get number of outputs of this symbol * \return number of outputs */ - inline size_t NumReturns() const { + inline size_t NumOutputs() const { return heads_.size(); } /*! @@ -401,10 +401,10 @@ class Executor { */ virtual void Backward(const std::vector &head_grads) = 0; /*! - * \brief get array of heads in the executor. - * \return array of heads in the executor. + * \brief get array of outputs in the executor. + * \return array of outputs in the executor. */ - virtual const std::vector &heads() const = 0; + virtual const std::vector &outputs() const = 0; /*! * \brief Create an operator by bind symbol with context and arguments. * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 189461074ffc..7c7ea5db449e 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -81,7 +81,8 @@ def backward(self, head_grads=None): ndarray = c_array(NDArrayHandle, [item.handle for item in head_grads]) check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), ndarray)) - def heads(self): + @property + def outputs(self): """list all heads' output ndarray Returns @@ -94,5 +95,5 @@ def heads(self): # if user set the content of the head, the backward behavior can be incorrect. out_size = mx_uint() handles = ctypes.POINTER(NDArrayHandle)() - check_call(_LIB.MXExecutorHeads(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) + check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 3735ca861f2d..179c4d6ca214 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -145,17 +145,17 @@ def list_arguments(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] - def list_returns(self): - """List all returns in the symbol. + def list_outputs(self): + """List all outputs in the symbol. Returns ------- returns : list of string - List of all the returns. + List of all the outputs. """ size = ctypes.c_uint() sarr = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXSymbolListReturns( + check_call(_LIB.MXSymbolListOutputs( self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] @@ -203,7 +203,7 @@ def infer_shape(self, *args, **kwargs): The order is in the same order as list_arguments() out_shapes : list of tuple or None List of shapes of outputs. - The order is in the same order as list_returns() + The order is in the same order as list_outputs() aux_shapes : list of tuple or None List of shapes of outputs. 
The order is in the same order as list_auxiliary() diff --git a/src/c_api.cc b/src/c_api.cc index 8fedf170b3f2..4a96d946d34f 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -556,13 +556,13 @@ int MXSymbolListArguments(SymbolHandle symbol, API_END(); } -int MXSymbolListReturns(SymbolHandle symbol, +int MXSymbolListOutputs(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { Symbol *s = static_cast(symbol); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = std::move(s->ListReturns()); + ret->ret_vec_str = std::move(s->ListOutputs()); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); @@ -705,13 +705,13 @@ int MXExecutorBackward(ExecutorHandle handle, API_END(); } -int MXExecutorHeads(ExecutorHandle handle, - mx_uint *out_size, - NDArrayHandle **out) { +int MXExecutorOutputs(ExecutorHandle handle, + mx_uint *out_size, + NDArrayHandle **out) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); Executor *exec = static_cast(handle); - std::vector heads = exec->heads(); + std::vector heads = exec->outputs(); ret->ret_handles.resize(heads.size()); for (size_t i = 0; i < heads.size(); ++i) { NDArray *ptr = new NDArray(); diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 1409615e853a..613913eb8284 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -239,11 +239,11 @@ class BatchNormProp : public OperatorProperty { return {{out_grad[kOut], in_grad[kData]}}; } - int NumVisibleReturns() const override { + int NumVisibleOutputs() const override { return 1; } - int NumReturns() const override { + int NumOutputs() const override { return 4; } @@ -251,7 +251,7 @@ class BatchNormProp : public OperatorProperty { return {"data", "gamma", "beta"}; } - std::vector ListReturns() const override { + std::vector ListOutputs() const override { return {"output", "output_no_affine", "mean", "var"}; } diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 7839f01164a8..10318e39beb1 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -23,9 +23,9 @@ class GraphExecutor::BackwardOpWrapper : public Operator { explicit BackwardOpWrapper(const OperatorProperty *prop, std::shared_ptr forward_op) : op_(forward_op) { - out_grad_.resize(prop->NumVisibleReturns()); + out_grad_.resize(prop->NumVisibleOutputs()); in_data_.resize(prop->ListArguments().size()); - out_data_.resize(prop->NumReturns()); + out_data_.resize(prop->NumOutputs()); std::vector out_grad_ptr(out_grad_.size()); for (size_t i = 0; i < out_grad_.size(); ++i) { @@ -88,7 +88,7 @@ GraphExecutor::GetResource(uint32_t node_id) const { inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const { const StaticGraph::Node &node = graph_.nodes[node_id]; if (node.is_forward()) { - return node.op->NumReturns(); + return node.op->NumOutputs(); } else if (node.is_backward()) { return static_cast( graph_.nodes[node.backward_source_id].op->ListArguments().size()); @@ -128,9 +128,9 @@ inline std::vector > GraphExecutor::GetInplaceOption( // forward property const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get(); - std::vector out_grad_index(fwd->NumVisibleReturns()); + std::vector out_grad_index(fwd->NumVisibleOutputs()); std::vector in_data_index(fwd->ListArguments().size()); - std::vector out_data_index(fwd->NumReturns()); + std::vector out_data_index(fwd->NumOutputs()); 
CHECK_EQ(in_data_index.size(), out_data.size()); int counter = 0; for (size_t i = 0; i < out_grad_index.size(); ++i) { diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 25b1ecd3a8bc..af752776df48 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -22,7 +22,7 @@ class GraphExecutor : public Executor { virtual ~GraphExecutor(); virtual void Forward(bool is_train); virtual void Backward(const std::vector &head_grads); - virtual const std::vector &heads() const { + virtual const std::vector &outputs() const { return heads_ndarray_; } // implement Executor::Bind, only call it once. diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index c24ee1a085d5..7cd6f7f2147b 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -133,7 +133,7 @@ bool StaticGraph::InferShape(std::vector *in_shape, for (size_t i = 0; i < nodes.size(); ++i) { int nout = 1; if (nodes[i].is_forward()) { - nout = nodes[i].op->NumReturns(); + nout = nodes[i].op->NumOutputs(); } else if (nodes[i].is_backward()) { nout = static_cast(nodes[nodes[i].backward_source_id].inputs.size()); } @@ -215,9 +215,9 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, // get out_grad and out_data entry std::vector out_grad, out_data; // nvisible is out_grad.size() - int nvisible = nodes[nid].op->NumVisibleReturns(); + int nvisible = nodes[nid].op->NumVisibleOutputs(); // ntotal is out_data.size() - int ntotal = nodes[nid].op->NumReturns(); + int ntotal = nodes[nid].op->NumOutputs(); // check all outpus for (int i = 0; i < ntotal; ++i) { DataEntry odata(nid, static_cast(i)); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 6daa30ef21d1..769cc5361f54 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -192,7 +192,7 @@ std::vector Symbol::ListArguments() const { } } -std::vector Symbol::ListReturns() const { +std::vector Symbol::ListOutputs() const { std::vector ret; for (auto &head : heads_) { if (head.source->is_variable()) { @@ -200,7 +200,7 @@ std::vector Symbol::ListReturns() const { } else { // TODO(bing) rethink about output naming auto &hname = head.source->name; - std::string rname = head.source->op->ListReturns()[head.index]; + std::string rname = head.source->op->ListOutputs()[head.index]; if (hname.length() == 0) { ret.push_back(std::move(rname)); } else { @@ -233,7 +233,7 @@ std::vector Symbol::ListAuxiliaryStates() const { } Symbol Symbol::operator[] (size_t index) const { - size_t nreturn = NumReturns(); + size_t nreturn = NumOutputs(); CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; if (nreturn == 1) { return *this; @@ -246,12 +246,12 @@ Symbol Symbol::operator[] (size_t index) const { void Symbol::Compose(const std::vector& args, const std::string& name) { - CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently"; CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; heads_[0].source->name = name; for (size_t i = 0; i < args.size(); ++i) { - CHECK_EQ(args[i].NumReturns(), 1) - << "Argument " << i << " is a tuple with " << args[i].NumReturns() + CHECK_EQ(args[i].NumOutputs(), 1) + << "Argument " << i << " is a tuple with " << args[i].NumOutputs() << " elements, scalar is required"; } // positional arguments requires all arguments for now. 
@@ -305,11 +305,11 @@ void Symbol::Compose(const std::vector& args, void Symbol::Compose(const std::unordered_map& kwargs, const std::string& name) { - CHECK_EQ(NumReturns(), 1) << "Only composition of value function is supported currently"; + CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently"; CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; heads_[0].source->name = name; for (const auto& kv : kwargs) { - CHECK_EQ(kv.second.NumReturns(), 1) + CHECK_EQ(kv.second.NumOutputs(), 1) << "Keyword Argument " << kv.first << " is a tuple, scalar is required"; } size_t nmatched = 0; @@ -483,7 +483,7 @@ bool Symbol::InferShape(const std::unordered_map& known_arg Symbol Symbol::Create(OperatorProperty *op) { // use special representation for atomic symbol auto node = std::make_shared(op, ""); - size_t nret = op->NumVisibleReturns(); + size_t nret = op->NumVisibleOutputs(); Symbol s; for (uint32_t i = 0; i < nret; ++i) { s.heads_.push_back(DataEntry(node, i)); diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py index 93a8b2d806d5..d9f737402e7a 100644 --- a/tests/python/train/test_conv.py +++ b/tests/python/train/test_conv.py @@ -53,7 +53,7 @@ def CalAcc(out, label): executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, 'write', aux_narrays) # update -out_narray = executor.heads()[0] +out_narray = executor.outputs[0] grad_narray = mx.nd.empty(out_narray.shape) epoch = 1 diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 7198761796a4..315564ea5057 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -38,7 +38,7 @@ def CalAcc(out, label): executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays) # update -out_narray = executor.heads()[0] +out_narray = executor.outputs[0] grad_narray = mx.nd.empty(out_narray.shape) epoch = 9 diff --git a/tests/python/unittest/test_bind.py b/tests/python/unittest/test_bind.py index 42f01b0556bb..0b34307e8a00 100644 --- a/tests/python/unittest/test_bind.py +++ b/tests/python/unittest/test_bind.py @@ -40,10 +40,10 @@ def check_bind_with_uniform(uf, gf, dim): executor.forward() exec3.forward() exec4.forward() - out2 = executor.heads()[0].asnumpy() + out2 = executor.outputs[0].asnumpy() out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy()) - out3 = exec3.heads()[0].asnumpy() - out4 = exec4.heads()[0].asnumpy() + out3 = exec3.outputs[0].asnumpy() + out4 = exec4.outputs[0].asnumpy() assert reldiff(out1, out2) < 1e-6 assert reldiff(out1, out3) < 1e-6 assert reldiff(out1, out4) < 1e-6 diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index dcdf76cf6d9c..2c07bb370ab6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -27,9 +27,9 @@ def check_elementwise_sum_with_shape(shape, n): exec1 = out.bind(mx.Context('cpu'), args=arr, args_grad=arr_grad) - out1 = exec1.heads()[0].asnumpy() + out1 = exec1.outputs[0].asnumpy() exec1.forward() - out1 = exec1.heads()[0].asnumpy() + out1 = exec1.outputs[0].asnumpy() out = sum(a.asnumpy() for a in arr) assert reldiff(out, out1) < 1e-6 out_grad = mx.nd.empty(shape) @@ -70,7 +70,7 @@ def check_concat_with_shape(shapes): args=arr, args_grad=arr_grad) exec1.forward() - out1 = exec1.heads()[0] + out1 = exec1.outputs[0] ret = np.concatenate([narray.asnumpy() for narray in arr], axis=1) assert same(out1.asnumpy(), ret) # backward diff --git a/tests/python/unittest/test_symbol.py 
b/tests/python/unittest/test_symbol.py
index b4dc93e1cfdd..9e29f0b9ffb5 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -6,7 +6,7 @@ def test_symbol_basic():
         mlist.append(models.mlp2())
     for m in mlist:
         m.list_arguments()
-        m.list_returns()
+        m.list_outputs()
 
 
 def test_compose():
@@ -25,4 +25,4 @@ def test_compose():
     composed = net2(fc3_data=net1, name='composed')
     print(composed.debug_str())
     multi_out = mx.symbol.Group([composed, net1])
-    assert len(multi_out.list_returns()) == 2
+    assert len(multi_out.list_outputs()) == 2

From cba8358f2779ffa6f2287cc4f8c103bb36097d7a Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 13 Sep 2015 11:50:01 -0700
Subject: [PATCH 13/13] fix lint

---
 python/mxnet/executor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 7c7ea5db449e..f57077adc919 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -95,5 +95,6 @@ def outputs(self):
         # if user set the content of the head, the backward behavior can be incorrect.
         out_size = mx_uint()
         handles = ctypes.POINTER(NDArrayHandle)()
-        check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles)))
+        check_call(_LIB.MXExecutorOutputs(self.handle,
+                                          ctypes.byref(out_size), ctypes.byref(handles)))
         return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)]
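
For reference, the user-facing effect of the renames in this series (mx.narray helpers exposed as mx.nd, Executor.heads() replaced by the Executor.outputs property, Symbol.list_returns() renamed to Symbol.list_outputs()) is sketched below. This is not part of the patches; it is a minimal sketch that mirrors the updated tests/python/unittest/test_operator.py and assumes an mxnet build that includes these commits.

# Minimal sketch (not part of the patch series): exercising the renamed Python API.
# Mirrors the updated tests/python/unittest/test_operator.py; assumes an mxnet build
# from these commits is importable as `mxnet`.
import numpy as np
import mxnet as mx

shape = (4, 4)
n = 3
inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
out = mx.symbol.ElementWiseSum(*inputs, name='esum')

# Symbol.list_returns() is now Symbol.list_outputs()
print(out.list_arguments())
print(out.list_outputs())

# the mx.narray helpers are now reachable under the mx.nd alias
arr = [mx.nd.empty(shape) for i in range(n)]
arr_grad = [mx.nd.empty(shape) for i in range(n)]
for i in range(n):
    arr[i][:] = np.random.uniform(-10, 10, shape)

exec1 = out.bind(mx.Context('cpu'), args=arr, args_grad=arr_grad)
exec1.forward()

# Executor.heads() is now the Executor.outputs property
out1 = exec1.outputs[0].asnumpy()
assert np.allclose(out1, sum(a.asnumpy() for a in arr))

out_grad = mx.nd.empty(shape)
out_grad[:] = np.random.uniform(-10, 10, shape)
exec1.backward([out_grad])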