From a91970702677c57410a3886154b738abbf77e212 Mon Sep 17 00:00:00 2001 From: winsty Date: Tue, 14 Jul 2015 20:17:39 +0800 Subject: [PATCH 1/3] narray operator --- Makefile | 3 +- include/mxnet/base.h | 34 +++++++ include/mxnet/narray.h | 4 + include/mxnet/narray_operator.h | 100 +++++++++++++++++++ include/mxnet/operator.h | 32 ------- src/narray_operator/narray_operator.cc | 127 +++++++++++++++++++++++++ src/operator/operator_common.h | 12 +-- 7 files changed, 273 insertions(+), 39 deletions(-) create mode 100644 include/mxnet/narray_operator.h create mode 100644 src/narray_operator/narray_operator.cc diff --git a/Makefile b/Makefile index 73e085051a8f..2a2c8398af56 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o narray_operator.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -85,6 +85,7 @@ operator_cpu.o: src/operator/operator_cpu.cc operator_gpu.o: src/operator/operator_gpu.cu api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc +narray_operator.o: src/narray_operator/narray_operator.cc api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index ca9da665124a..67c3a1b24b74 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -40,5 +40,39 @@ typedef mshadow::gpu gpu; typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; + +/*! \brief option to pass into the forward function */ +struct Option { + /*! \brief whether it is training phase*/ + int is_train; +}; +/*! \brief gradient request type the request can have */ +enum GradReqType { + /*! \brief no operation, do not write gradient */ + kNullOp = 0, + /*! \brief write gradient to provided space */ + kWriteTo = 1, + /*! \brief same as kWriteTo, but provided space is same as space of input-data */ + kWriteInplace = 2, + /*! \brief add to the provided space */ + kAddTo = 3 +}; +/*! \brief input argument type of the operator have */ +enum ArgType { + /*! \brief data argument */ + kDataArg = 0, + /*! \brief weight argument */ + kWeightArg = 1, + /*! \brief bias argument */ + kBiasArg = 2 +}; +/*! \brief Property for engine schedule */ +enum Property { + /*! \brief Op contains interanl state, won't influence engine schedule */ + kContainInteralState = 1, + /*! \brief Op forward require random number, will influence engine schedule */ + kForwardRequireRnd = 2, +}; + } // namespace mxnet #endif // MXNET_BASE_H_ diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 637826ed5a09..458bf9f6c834 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -72,6 +72,10 @@ class NArray { if (is_none()) return; DAGEngine::Get()->WaitForVar(ptr_->var); } + /*! \return the associated DAG variable of the narray.*/ + inline DAGEngine::Variable Var() const { + return ptr_->var; + } /*! * \brief set all the elements in narray to be scalar * \param scalar the scalar to set diff --git a/include/mxnet/narray_operator.h b/include/mxnet/narray_operator.h new file mode 100644 index 000000000000..e55ecc193453 --- /dev/null +++ b/include/mxnet/narray_operator.h @@ -0,0 +1,100 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file narray_operator.h + * \brief narray operator interface of mxnet + * \author Naiyan Wang + */ +#ifndef MXNET_NARRAY_OPERATOR_H_ +#define MXNET_NARRAY_OPERATOR_H_ +// this file will be seen by cuda, no c++11 for now +#include +#include +#include "./base.h" +#include "./tensor_blob.h" +#include "./operator.h" +#include "./narray.h" +#include "./dag_engine.h" + +namespace mxnet { +/*! + * \brief static operator interface (current interface have not yet todo with scheduler), + * operator is a stateful object that can be used to call forward and backprop + * + * This interface relies on pre-allocated memory in TBlob, the caller need to set + * the memory region in TBlob correctly before calling Forward and Backward + * + * \sa Operator + */ +class NArrayOperator { + public: + NArrayOperator(Operator* op, Context ctx); + /*! + * \brief get types of input argument of this oeprator + * \return a vector corresponding to type of each argument + * this order is same as the order of inputs in Forward, InferShape and Backward + */ + virtual std::vector DescribeArgs() const; + /*! + * \brief describe property of op + * \return a bit map in int + */ + virtual int DescribeProperty() const; + /*! + * \brief set param for the operator from string + * \param name parameter name + * \param val string for configuration + */ + virtual void SetParam(const char *name, const char *val); + /*! + * \brief inter the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + */ + virtual void InferShape(std::vector *in_shape, + std::vector *out_shape); + + virtual void SetContext(Context ctx); + /*! + * \brief perform a forward operation of operator, save the output to TBlob + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of TBlob in out_data must be pre-allocated with InferShape + */ + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data); + /*! 
+ * \brief perform a backward operation of the operator to get the gradient + * \param ctx runtime context + * \param grad_next the gradient value we get from output of the operator + * \param in_data the array of input data + * \param out_grad array of output gradient, there could be three possible TBlob + * in the each element in the array + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType + */ + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req); + +private: + /* \brief the static operator */ + Operator* op; + Context global_ctx; +}; +} // namespace mxnet +#endif // MXNET_NARRAY_OPERATOR_H_ diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index aed0f63349b7..a2796537f3a9 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -23,38 +23,6 @@ namespace mxnet { */ class Operator { public: - /*! \brief option to pass into the forward function */ - struct Option { - /*! \brief whether it is training phase*/ - int is_train; - }; - /*! \brief gradient request type the request can have */ - enum GradReqType { - /*! \brief no operation, do not write gradient */ - kNullOp = 0, - /*! \brief write gradient to provided space */ - kWriteTo = 1, - /*! \brief same as kWriteTo, but provided space is same as space of input-data */ - kWriteInplace = 2, - /*! \brief add to the provided space */ - kAddTo = 3 - }; - /*! \brief input argument type of the operator have */ - enum ArgType { - /*! \brief data argument */ - kDataArg = 0, - /*! \brief weight argument */ - kWeightArg = 1, - /*! \brief bias argument */ - kBiasArg = 2 - }; - /*! \brief Property for engine schedule */ - enum Property { - /*! \brief Op contains interanl state, won't influence engine schedule */ - kContainInteralState = 1, - /*! \brief Op forward require random number, will influence engine schedule */ - kForwardRequireRnd = 2, - }; /*! * \brief get types of input argument of this oeprator * \return a vector corresponding to type of each argument diff --git a/src/narray_operator/narray_operator.cc b/src/narray_operator/narray_operator.cc new file mode 100644 index 000000000000..7eaf934e1ec0 --- /dev/null +++ b/src/narray_operator/narray_operator.cc @@ -0,0 +1,127 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file narray_operator.cc + * \brief the implementation of narray operator + * \author Naiyan Wang + */ +#include + +namespace mxnet { + + NArrayOperator::NArrayOperator(Operator* op, Context ctx) { + this->op = op; + this->global_ctx = ctx; + } + /*! + * \brief get types of input argument of this oeprator + * \return a vector corresponding to type of each argument + * this order is same as the order of inputs in Forward, InferShape and Backward + */ + std::vector NArrayOperator::DescribeArgs() const { + // default most of layers only have one data argument + return op->DescribeArgs(); + } + /*! + * \brief describe property of op + * \return a bit map in int + */ + int NArrayOperator::DescribeProperty() const { + // default most of layer only conatin internal state + return op->DescribeProperty(); + } + /*! + * \brief set param for the operator from string + * \param name parameter name + * \param val string for configuration + */ + void NArrayOperator::SetParam(const char *name, const char *val) { + op->SetParam(name, val); + } + /*! 
+ * \brief inter the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + */ + void NArrayOperator::InferShape(std::vector *in_shape, + std::vector *out_shape) { + op->InferShape(in_shape, out_shape); + } + /*! + * \brief perform a forward operation of operator, save the output to TBlob + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of TBlob in out_data must be pre-allocated with InferShape + */ + void NArrayOperator::Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data) { + std::vector used_var; + std::vector mutate_var; + std::vector in; + std::vector out; + for (size_t i = 0; i < in_data.size(); ++i) { + used_var.push_back(in_data[i].Var()); + in.push_back(in_data[i].data()); + } + for (size_t i = 0; i < out_data.size(); ++i) { + mutate_var.push_back(out_data[i].Var()); + out.push_back(out_data[i].data()); + } + DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { + op->Forward(opt, ctx, in, out); + }, global_ctx, used_var, mutate_var); + } + /*! + * \brief perform a backward operation of the operator to get the gradient + * \param ctx runtime context + * \param grad_next the gradient value we get from output of the operator + * \param in_data the array of input data + * \param out_grad array of output gradient, there could be three possible TBlob + * in the each element in the array + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType + */ + void NArrayOperator::Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) { + std::vector used_var; + std::vector mutate_var; + std::vector grad_in; + std::vector grad_out; + std::vector data; + for (size_t i = 0; i < grad_next.size(); ++i) { + used_var.push_back(grad_next[i].Var()); + grad_in.push_back(grad_next[i].data()); + } + for (size_t i = 0; i < in_data.size(); ++i) { + used_var.push_back(in_data[i].Var()); + data.push_back(in_data[i].data()); + } + for (size_t i = 0; i < out_grad.size(); ++i) { + mutate_var.push_back(out_grad[i].Var()); + grad_out.push_back(out_grad[i].data()); + } + DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { + op->Backward(ctx, grad_in, data, grad_out, req); + }, global_ctx, used_var, mutate_var); + } + + void NArrayOperator::SetContext(Context ctx) { + this->global_ctx = ctx; + } + +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index ce65100f4c39..f87ffdbb7efb 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,7 +11,7 @@ #include #include - +#include namespace mxnet { namespace op { /*! 
@@ -24,13 +24,13 @@ namespace op { */ template inline void Assign(OType &out, // NOLINT(*) - Operator::GradReqType req, + GradReqType req, const Exp &exp) { switch (req) { - case Operator::kNullOp: break; - case Operator::kWriteTo: - case Operator::kWriteInplace: out = exp; break; - case Operator::kAddTo: out += exp; break; + case kNullOp: break; + case kWriteTo: + case kWriteInplace: out = exp; break; + case kAddTo: out += exp; break; default: LOG(FATAL) << "not reached"; } } From 3a3003c8200a423dc55b09a74fac3950440d765e Mon Sep 17 00:00:00 2001 From: winsty Date: Wed, 15 Jul 2015 19:15:13 +0800 Subject: [PATCH 2/3] merge makefile --- Makefile | 8 +- api/mxnet_api.cc | 75 ++++++++++++ api/mxnet_api.h | 70 ++++++++++-- api/python/mxnet/__init__.py | 3 + api/python/mxnet/base.py | 3 +- api/python/mxnet/symbol.py | 73 ++++++++++++ api/python/mxnet/symbol_creator.py | 69 +++++++++++ doc/Doxyfile | 2 +- include/mxnet/api_registry.h | 66 +++++++++++ include/mxnet/atomic_symbol.h | 74 ++++++++++++ include/mxnet/symbol.h | 151 +++++++++++++++++++++++++ src/api_registry.cc | 22 ++++ src/common/concurrent_blocking_queue.h | 54 +++++++-- src/common/spin_lock.h | 29 ++++- src/dag_engine/threaded_engine.cc | 71 ++++++------ src/symbol/symbol.cc | 141 +++++++++++++++++++++++ test/test_threaded_engine.cc | 35 +++++- 17 files changed, 884 insertions(+), 62 deletions(-) create mode 100644 api/python/mxnet/symbol.py create mode 100644 api/python/mxnet/symbol_creator.py create mode 100644 include/mxnet/atomic_symbol.h create mode 100644 include/mxnet/symbol.h create mode 100644 src/symbol/symbol.cc diff --git a/Makefile b/Makefile index 2a2c8398af56..460bda8452df 100644 --- a/Makefile +++ b/Makefile @@ -54,10 +54,11 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif +#BIN = test/test_threaded_engine test/api_registry_test BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o narray_operator.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o narray_operator.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -76,13 +77,14 @@ $(DMLC_CORE)/libdmlc.a: storage.o: src/storage/storage.cc engine.o: src/dag_engine/simple_engine.cc -threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h +#engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h operator.o: src/operator/operator.cc operator_cpu.o: src/operator/operator_cpu.cc operator_gpu.o: src/operator/operator_gpu.cu +symbol.o: src/symbol/symbol.cc api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc narray_operator.o: src/narray_operator/narray_operator.cc @@ -91,6 +93,7 @@ api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) test/api_registry_test: test/api_registry_test.cc api/libmxnet.a +#test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) @@ -123,4 +126,3 @@ doc: clean: $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - - diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 
0d0575ba488c..7e351e52cb03 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -246,3 +246,78 @@ int MXFuncInvoke(FunctionHandle fun, (NArray**)(mutate_vars)); // NOLINT(*) API_END(); } + +int MXSymFree(SymbolHandle sym) { + API_BEGIN(); + delete static_cast(sym); + API_END(); +} + +int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, + mx_uint *use_param) { + API_BEGIN(); + auto *sc = static_cast(sym_creator); + *use_param = sc->use_param ? 1 : 0; + API_END(); +} + +int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, + int count, + const char** keys, + const char** vals, + SymbolHandle* out) { + API_BEGIN(); + const SymbolCreatorRegistry::Entry *sc = + static_cast(sym_creator); + sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*) + API_END(); +} + +int MXListSymCreators(mx_uint *out_size, + SymbolCreatorHandle **out_array) { + API_BEGIN(); + auto &vec = SymbolCreatorRegistry::List(); + *out_size = static_cast(vec.size()); + *out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) + API_END(); +} + +int MXGetSymCreator(const char *name, + SymbolCreatorHandle *out) { + API_BEGIN(); + *out = SymbolCreatorRegistry::Find(name); + API_END(); +} + +int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, + const char **out_name) { + API_BEGIN(); + auto *f = static_cast(sym_creator); + *out_name = f->name.c_str(); + API_END(); +} + +int MXSymbolCompose(SymbolHandle sym, + mx_uint num_args, + const char** keys, + SymbolHandle* args, + SymbolHandle* out) { + API_BEGIN(); + const Symbol* s = static_cast(sym); + Symbol* ret = new Symbol; + if (keys == NULL) { + std::vector pos_args; + for (mx_uint i = 0; i < num_args; ++i) { + pos_args.push_back(*(Symbol*)(args[i])); // NOLINT(*) + } + *ret = (*s)(pos_args); + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = *(Symbol*)(args[i]); // NOLINT(*) + } + *ret = (*s)(kwargs); + } + *out = ret; + API_END(); +} diff --git a/api/mxnet_api.h b/api/mxnet_api.h index d30a18d571dd..fb4b8710e2e6 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -28,6 +28,8 @@ typedef float mx_float; typedef void *NArrayHandle; /*! \brief handle to a mxnet narray function that changes NArray */ typedef const void *FunctionHandle; +/*! \brief handle to a function that takes param and creates symbol */ +typedef const void *SymbolCreatorHandle; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a NArrayOperator */ @@ -217,17 +219,69 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, * \param sym the symbol * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymFree(SymbolHandle *sym); +MXNET_DLL int MXSymFree(SymbolHandle sym); /*! - * \brief set the parameter in to current symbol - * \param sym the symbol - * \param name name of the parameter - * \param val value of the parameter + * \brief query if the symbol creator needs param. + * \param sym_creator the symbol creator handle + * \param use_param describe if the symbol creator requires param + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, + mx_uint *use_param); +/*! + * \brief invoke registered symbol creator through its handle. + * \param sym_creator pointer to the symbolcreator function. + * \param count the number of the key value pairs in the param. + * \param keys an array of c str. + * \param vals the corresponding values of the keys. 
+ * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, + int count, + const char** keys, + const char** vals, + SymbolHandle* out); +/*! + * \brief list all the available sym_creator + * most user can use it to list all the needed sym_creators + * \param out_size the size of returned array + * \param out_array the output sym_creators + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXListSymCreators(mx_uint *out_size, + SymbolCreatorHandle **out_array); +/*! + * \brief get the sym_creator by name + * \param name the name of the sym_creator + * \param out the corresponding sym_creator + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXGetSymCreator(const char *name, + SymbolCreatorHandle *out); +/*! + * \brief get the name of sym_creator handle + * \param sym_creator the sym_creator handle + * \param out_name the name of the sym_creator * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymSetParam(SymbolHandle sym, - const char *name, - const char *val); +MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, + const char **out_name); +/*! + * \brief compose the symbol on other symbol + * \param sym the symbol to apply + * \param num_args number of arguments + * \param keys the key of keyword args (optional) + * \param args arguments to sym + * \param out the resulting symbol + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCompose(SymbolHandle sym, + mx_uint num_args, + const char** keys, + SymbolHandle* args, + SymbolHandle* out); + //-------------------------------------------- // Part 4: operator interface on NArray //-------------------------------------------- diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py index 28b1659efb75..dabf0b795412 100644 --- a/api/python/mxnet/__init__.py +++ b/api/python/mxnet/__init__.py @@ -13,6 +13,9 @@ from .context import Context, current_context from .narray import NArray from .function import _FunctionRegistry +from .symbol import Symbol +from .symbol_creator import _SymbolCreatorRegistry # this is a global function registry that can be used to invoke functions op = NArray._init_function_registry(_FunctionRegistry()) +sym = Symbol._init_symbol_creator_registry(_SymbolCreatorRegistry()) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index e7e20394f738..239e284293b4 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -55,6 +55,8 @@ def _load_lib(): mx_float = ctypes.c_float NArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p +SymbolCreatorHandle = ctypes.c_void_p +SymbolHandle = ctypes.c_void_p #---------------------------- # helper function definition @@ -131,4 +133,3 @@ def ctypes2numpy_shared(cptr, shape): size *= s dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape) - diff --git a/api/python/mxnet/symbol.py b/api/python/mxnet/symbol.py new file mode 100644 index 000000000000..6f4146d162e3 --- /dev/null +++ b/api/python/mxnet/symbol.py @@ -0,0 +1,73 @@ +# coding: utf-8 +"""Symbol support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB +from .base import c_array, c_str +from .base import SymbolHandle +from .base import check_call + +class Symbol(object): + """SymbolCreator is a function that takes Param and return symbol""" 
+ _registry = None + + @staticmethod + def _init_symbol_creator_registry(symbol_creator_registry): + """Initialize symbol creator registry + + Parameters + ---------- + symbol_creator_registry: + pass in symbol_creator_registry + Returns + ------- + the passed in registry + """ + _registry = symbol_creator_registry + return _registry + + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + self.handle = handle + + def __call__(self, *args, **kwargs): + """Compose Symbols + + Parameters + ---------- + args: + provide positional arguments + kwargs: + provide keyword arguments + Returns + ------- + the resulting symbol + """ + assert (len(args) == 0 or len(kwargs) == 0) + for arg in args: + assert isinstance(arg, Symbol) + for _, val in kwargs: + assert isinstance(val, Symbol) + num_args = len(args) + len(kwargs) + if len(kwargs) != 0: + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) + args = c_array(SymbolHandle, kwargs.values()) + else: + keys = None + args = c_array(SymbolHandle, args) + + out = SymbolHandle() + check_call(_LIB.MXSymbolCompose( + self.handle, + num_args, + keys, + args, + ctypes.byref(out))) + return Symbol(out) diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py new file mode 100644 index 000000000000..e8a49149ec35 --- /dev/null +++ b/api/python/mxnet/symbol_creator.py @@ -0,0 +1,69 @@ +# coding: utf-8 +"""Symbol support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB +from .base import c_array, c_str +from .base import mx_uint, SymbolHandle +from .base import check_call +from .symbol import Symbol + +class _SymbolCreator(object): + """SymbolCreator is a function that takes Param and return symbol""" + + def __init__(self, handle, name): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolCreatorHandle + the function handle of the function + + name : string + the name of the function + """ + self.handle = handle + self.name = name + use_param = mx_uint() + check_call(_LIB.MXSymCreatorDescribe( + self.handle, + ctypes.byref(use_param))) + self.use_param = use_param.value + + def __call__(self, **kwargs): + """Invoke creator of symbol by passing kwargs + + Parameters + ---------- + params : kwargs + provide the params necessary for the symbol creation + Returns + ------- + the resulting symbol + """ + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) + vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) + sym_handle = SymbolHandle() + check_call(_LIB.MXSymCreatorInvoke( + self.handle, + mx_uint(len(kwargs)), + keys, + vals, + ctypes.byref(sym_handle))) + return Symbol(sym_handle) + +class _SymbolCreatorRegistry(object): + """Function Registry""" + def __init__(self): + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(_LIB.MXListSymCreators(ctypes.byref(size), + ctypes.byref(plist))) + hmap = {} + for i in range(size.value): + hdl = plist[i] + name = ctypes.c_char_p() + check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name))) + hmap[name.value] = _SymbolCreator(hdl, name.value) + self.__dict__.update(hmap) diff --git a/doc/Doxyfile b/doc/Doxyfile index b3d9d7fdbb81..f1f8f62bf4c0 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -773,7 +773,7 @@ INPUT_ENCODING = UTF-8 # *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf, # 
*.qsf, *.as and *.js. -FILE_PATTERNS = *.cc *.h +FILE_PATTERNS = *.h # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 403201f93ac4..91083c2ea11d 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -17,6 +17,7 @@ #include #include "./base.h" #include "./narray.h" +#include "./symbol.h" namespace mxnet { @@ -211,5 +212,70 @@ class FunctionRegistry { static auto __ ## name ## _narray_fun__ = \ ::mxnet::FunctionRegistry::Get()->Register("" # name) +/*! \brief registry of symbol creator */ +class SymbolCreatorRegistry { + public: + /*! \brief SymbolCreator is a function pointer */ + typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**); + /*! \return get a singleton */ + static SymbolCreatorRegistry *Get(); + /*! \brief keep the SymbolCreator function and its meta information */ + struct Entry { + /*! \brief the name of the symbol creator */ + std::string name; + /*! \brief the body of the function */ + SymbolCreator body; + /*! \brief if the creator requires params to construct */ + bool use_param; + /*! \brief constructor */ + explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {} + /*! \brief setter of body */ + inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; } + /*! \brief setter of use_param */ + inline Entry& set_use_param(bool up) { use_param = up; return *this; } + }; + /*! + * \brief register a name symbol under name + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + Entry &Register(const std::string& name); + /*! \return list of functions in the registry */ + inline static const std::vector &List() { + return Get()->fun_list_; + } + /*! + * \brief find an symbolcreator entry with corresponding name + * \param name name of the symbolcreator + * \return the corresponding symbolcreator, can be NULL + */ + inline static const Entry *Find(const std::string &name) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(name); + if (p != fmap.end()) { + return p->second; + } else { + return nullptr; + } + } + + private: + /*! \brief list of functions */ + std::vector fun_list_; + /*! \brief map of name->function */ + std::map fmap_; + /*! \brief constructor */ + SymbolCreatorRegistry() {} + /*! \brief destructor */ + ~SymbolCreatorRegistry(); +}; + +/*! + * \brief macro to register symbol creator + */ +#define REGISTER_SYMBOL_CREATOR(name) \ + static auto __ ## name ## _symbol_creator__ = \ + ::mxnet::SymbolCreatorRegistry::Get()->Register("" # name) + } // namespace mxnet #endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h new file mode 100644 index 000000000000..086bba9c6bae --- /dev/null +++ b/include/mxnet/atomic_symbol.h @@ -0,0 +1,74 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file atomic_symbol.h + * \brief atomic symbol interface of mxnet + */ +#ifndef MXNET_ATOMIC_SYMBOL_H_ +#define MXNET_ATOMIC_SYMBOL_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./tensor_blob.h" + +namespace mxnet { +class Operator; +/*! + * \brief AtomicSymbol is the base class of all atomic symbols. + * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance + * of AtomicSymbol can be shared in the graphs of different Symbols + */ +class AtomicSymbol { + public: + /*! 
+ * \brief Constructor with param as the argument. + */ + AtomicSymbol(); + /*! + * \brief virtual destructor + */ + virtual ~AtomicSymbol(); + /*! \brief get the descriptions of inputs for this symbol */ + virtual std::vector DescribeArguments() const = 0; + /*! \brief get the descriptions of outputs for this symbol */ + virtual std::vector DescribeReturns() const = 0; + /*! + * \brief set param for the symbol from string + * \param name parameter name + * \param val string for the configuration + */ + virtual void SetParam(const char *name, const char *val) {} + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) = 0; + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. + * \return a pointer to the copied atomic symbol + */ + virtual AtomicSymbol* Copy() const = 0; + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. + * Calling bind from the Symbol wrapper would generate a NArrayOperator. + */ + virtual Operator* Bind(Context ctx) const = 0; + friend class Symbol; +}; + +} // namespace mxnet +#endif // MXNET_ATOMIC_SYMBOL_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h new file mode 100644 index 000000000000..a623cfa1502e --- /dev/null +++ b/include/mxnet/symbol.h @@ -0,0 +1,151 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbol.h + * \brief symbol interface of mxnet + */ +#ifndef MXNET_SYMBOL_H_ +#define MXNET_SYMBOL_H_ + +#include +#include +#include +#include +#include +#include +#include "./base.h" +#include "./tensor_blob.h" + +namespace mxnet { +class NArrayOperator; +/*! + * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol + * should support expressions and often passed by value. While AtomicSymbol have many subclasses, + * passing by value would result in object slicing. + * + * Symbol is always composite, the head Node is the output node of the symbol. + * A atomic symbol can be seen as a special case of the composite symbol with only the head node. + */ +class Symbol { + protected: + /*! + * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol + * with input symbols. + */ + struct Node { + /*! wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! name of the node */ + std::string name_; + /*! inputs to this node */ + std::vector > in_symbol_; + /*! index of the inputs if the inputs are tuple */ + std::vector in_index_; + /*! the output shape of the wrapped symbol */ + std::vector out_shape_; + /*! 
+ * \brief constructor + */ + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = ""); + /*! + * \brief destructor + */ + ~Node(); + }; + /*! \brief the head node of the Symbol, it could be shared in many graphs */ + std::shared_ptr head_; + /*! \brief if the head has multiple return values, index is used to specify */ + int index_; + /*! \brief find the nodes that use placeholder arguments */ + std::shared_ptr > > arg_users_; + /*! \brief find arg users */ + void FindArgUsers(); + + public: + /*! + * \brief bind to device and returns an NArrayOperator. + * \param ctx context of the operator + * \return returns the pointer to a created NArrayOperator. It is on the user to delete. + */ + virtual NArrayOperator* Bind(Context ctx) const { return nullptr; } + /*! + * \brief copy the symbol + * \return a deep copy of the graph + */ + virtual Symbol Copy() const; + /*! + * \brief compose with arguments + * \param args positional arguments for the symbol + * \return a new Symbol which is the composition of current symbol with its arguments + */ + virtual Symbol operator () (const std::vector& args) const; + /*! + * \brief compose with named arguments + * \param kwargs keyword arguments for the symbol + * \return a new symbol which is the composition of current symbol with its arguments + */ + virtual Symbol operator () (const std::unordered_map& kwargs) const; + /*! + * \brief get the index th element from the returned tuple. + */ + virtual Symbol operator[] (int index) const; + /*! + * \brief arguments information + * \return the arguments list of this symbol, they can be either named or unnamed (empty string). + */ + virtual std::vector ListArgs(); + /*! + * \brief create atomic symbol wrapped in symbol + * \param param the parameter stored as key value pairs + * \return the constructed Symbol + */ + template + static Symbol CreateSymbol(const std::vector >& param) { + Symbol* s; + std::vector keys(param.size()); + std::vector vals(param.size()); + for (auto p : param) { + keys.push_back(p.first.c_str()); + vals.push_back(p.second.c_str()); + } + CreateSymbol(param.size(), &keys[0], &vals[0], &s); + Symbol ret = *s; + delete s; + return ret; + } + /*! + * \brief c api for CreateSymbol, this can be registered with SymbolCreatorRegistry + * \param num_param the number of params + * \param keys the key for the params + * \param vals values of the params + * \param out stores the returning symbol + */ + template + friend void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out); +}; + +template +void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out) { + Symbol* s = new Symbol; + Atomic* atom = new Atomic; + for (int i = 0; i < num_param; ++i) { + atom->SetParam(keys[i], vals[i]); + } + std::vector args = atom->DescribeArguments(); + std::vector rets = atom->DescribeReturns(); + // set head_ + s->head_ = std::make_shared(atom, ""); + // set index_ + s->index_ = rets.size() > 1 ? 
-1 : 0; + // set head_->in_index_ + s->head_->in_index_ = std::vector(args.size(), 0); + // set head_->in_symbol_ + for (auto name : args) { + s->head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + } + // set head_->out_shape_ + s->head_->out_shape_ = std::vector(rets.size()); + *out = s; +} + +} // namespace mxnet +#endif // MXNET_SYMBOL_H_ diff --git a/src/api_registry.cc b/src/api_registry.cc index 0a6423441bbe..c029502287e9 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -29,4 +29,26 @@ FunctionRegistry *FunctionRegistry::Get() { return &instance; } +// SymbolCreatorRegistry + +SymbolCreatorRegistry::Entry& +SymbolCreatorRegistry::Register(const std::string& name) { + CHECK_EQ(fmap_.count(name), 0); + Entry *e = new Entry(name); + fmap_[name] = e; + fun_list_.push_back(e); + return *e; +} + +SymbolCreatorRegistry::~SymbolCreatorRegistry() { + for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { + delete p->second; + } +} + +SymbolCreatorRegistry *SymbolCreatorRegistry::Get() { + static SymbolCreatorRegistry instance; + return &instance; +} + } // namespace mxnet diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h index aab39895b119..14bab00d8280 100644 --- a/src/common/concurrent_blocking_queue.h +++ b/src/common/concurrent_blocking_queue.h @@ -1,4 +1,11 @@ -#pragma once +/*! + * Copyright (c) 2015 by Contributors + * \file concurrent_blocking_queue.h + * \brief A simple lock-based consumer-producer queue. + */ +#ifndef MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ +#define MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ + #include #include #include @@ -6,11 +13,21 @@ #include #include +namespace common { + +/*! + * \brief A simple lock-based consumer-producer queue. + */ template class ConcurrentBlockingQueue { - const static int BUSY_LOOP = 1000; + static const int kBusyLoop = 1000; + public: ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) { } + /*! + * \brief Push object into the queue. Notify anyone who is waiting. + * \param e the object + */ void Push(const T& e) { std::lock_guard lock(mutex_); has_elmt_ = true; @@ -19,15 +36,22 @@ template class ConcurrentBlockingQueue { cv_.notify_all(); } } - bool Pop(T& rv) { - for (int i = 0; i < BUSY_LOOP; i++) { + /*! + * \brief Pop object out of the queue. If the queue is empty, the caller thread will sleep until + * (1) Producer pushed some product into the queue and the caller thread wins it. + * (2) A kill signal is passed to the queue. + * \param rv the pointer point to the return object + * \return whether an object is returned + */ + bool Pop(T* rv) { + for (int i = 0; i < kBusyLoop; i++) { if (has_elmt_) { std::lock_guard lock(mutex_); if (!has_elmt_) { assert(queue_.empty()); continue; } - rv = queue_.front(); + *rv = queue_.front(); queue_.pop_front(); if (queue_.empty()) has_elmt_ = false; @@ -40,28 +64,38 @@ template class ConcurrentBlockingQueue { cv_.wait(lock); } if (!exit_now_) { - rv = queue_.front(); + *rv = queue_.front(); queue_.pop_front(); if (queue_.empty()) has_elmt_ = false; return false; } else { - return true; + return true; } } } + /*! + * \brief pop all objects in the queue. + * \return a list containing all objects in the queue. + */ std::list PopAll() { std::lock_guard lock(mutex_); std::list rv; rv.swap(queue_); return rv; } - // Call `SignalForKill` before destruction + /*! + * \brief tell the queue to release all waiting consumers + */ void SignalForKill() { std::unique_lock lock(mutex_); exit_now_ = true; cv_.notify_all(); } + /*! 
+ * \brief return the current queue size + * \return queue size + */ size_t QueueSize() { std::unique_lock lock(mutex_); return queue_.size(); @@ -77,3 +111,7 @@ template class ConcurrentBlockingQueue { ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete; ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete; }; + +} // namespace common + +#endif // MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h index 5a0cc3f786e6..60850f171ecf 100644 --- a/src/common/spin_lock.h +++ b/src/common/spin_lock.h @@ -1,17 +1,18 @@ -#ifndef _SPINLOCK_XCHG_H -#define _SPINLOCK_XCHG_H - -/* Spin lock using xchg. +/* Copyright (c) 2015 by Contributors + * Spin lock using xchg. * Copied from http://locklessinc.com/articles/locks/ */ +#ifndef MXNET_COMMON_SPIN_LOCK_H_ +#define MXNET_COMMON_SPIN_LOCK_H_ + /* Compile read-write barrier */ #define barrier() asm volatile("": : :"memory") /* Pause instruction to prevent excess processor bus usage */ #define cpu_relax() asm volatile("pause\n": : :"memory") -static inline unsigned short xchg_8(void *ptr, unsigned char x) { +static inline unsigned short xchg_8(void *ptr, unsigned char x) { // NOLINT(*) __asm__ __volatile__("xchgb %0,%1" :"=r" (x) :"m" (*(volatile unsigned char *)ptr), "0" (x) @@ -23,8 +24,15 @@ static inline unsigned short xchg_8(void *ptr, unsigned char x) { #define BUSY 1 typedef unsigned char spinlock; +/*! + * \brief use this value to initialize lock object + */ #define SPINLOCK_INITIALIZER 0 +/*! + * \brief lock + * \param lock the pointer to lock object + */ static inline void spin_lock(spinlock *lock) { while (1) { if (!xchg_8(lock, BUSY)) return; @@ -33,13 +41,22 @@ static inline void spin_lock(spinlock *lock) { } } +/*! + * \brief unlock + * \param lock the pointer to lock object + */ static inline void spin_unlock(spinlock *lock) { barrier(); *lock = 0; } +/*! 
+ * \brief try lock + * \param lock the pointer to lock object + * \return whether the lock is grabbed or not + */ static inline int spin_trylock(spinlock *lock) { return xchg_8(lock, BUSY); } -#endif /* _SPINLOCK_XCHG_H */ +#endif // MXNET_COMMON_SPIN_LOCK_H_ diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc index 143b5e72f413..e5b44d5d1db2 100644 --- a/src/dag_engine/threaded_engine.cc +++ b/src/dag_engine/threaded_engine.cc @@ -1,3 +1,4 @@ +// Copyright (c) 2015 by Contributors #include #include #include @@ -6,8 +7,8 @@ #include #include -#include -#include +#include "dmlc/logging.h" +#include "mxnet/dag_engine.h" #include "../common/spin_lock.h" #include "../common/concurrent_blocking_queue.h" @@ -19,14 +20,14 @@ namespace mxnet { class ThreadedEngine : public DAGEngine { public: - ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { - for(int i = 0; i < numthreads; ++i) { + explicit ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { + for (int i = 0; i < numthreads; ++i) { worker_queues_.push_back(new ConcurrentBlockingQueue()); workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i); } } ~ThreadedEngine() { - for(int i = 0; i < numthreads_; ++i) { + for (int i = 0; i < numthreads_; ++i) { worker_queues_[i]->SignalForKill(); delete worker_queues_[i]; workers_[i].join(); @@ -36,10 +37,10 @@ class ThreadedEngine : public DAGEngine { Context exec_ctx, const vector &use_vars, const vector &mutate_vars) override { - shared_ptr opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, - [this] (OpDescr* o) { this->OnDepsResolved(o); } ); - for( Variable v : use_vars ) { // read - VarDescr* vard = static_cast(v); // safe to cast here + shared_ptr opd(new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, + [this] (OpDescr* o) { this->OnDepsResolved(o); }); + for ( Variable v : use_vars ) { // read + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); if (vard->rw < 0) { vard->waitings.push(make_pair(opd, DepType::kRead)); @@ -48,8 +49,8 @@ class ThreadedEngine : public DAGEngine { } spin_unlock(&vard->lock); } - for( Variable v : mutate_vars ) { // write - VarDescr* vard = static_cast(v); // safe to cast here + for ( Variable v : mutate_vars ) { // write + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); if (vard->rw != 0) { vard->waitings.push(make_pair(opd, DepType::kWrite)); @@ -67,28 +68,29 @@ class ThreadedEngine : public DAGEngine { exec_fun(ctx); on_complete(); }, exec_ctx, use_vars, mutate_vars); } - void PushDelete(Op delete_fun, Variable var) override { - // TODO + void PushDelete(Op delete_fun, Context exec_ctx, Variable var) override { this->Push([delete_fun, var] (RunContext ctx) { delete_fun(ctx); - delete static_cast(var); - }, Context()/* TODO exec_ctx is missing?*/, {}, {var}); + delete static_cast(var); // TODO(minjie): use variable pool instead + }, exec_ctx, {}, {var}); } Variable NewVar() override { // in practice return a ptr to a cell // that have the info about the variable // use ptr directly instead of ID because this avoids an indirect mapping + // TODO(minjie): use variable pool instead VarDescr* vd = new VarDescr; vd->lock = SPINLOCK_INITIALIZER; vd->rw = 0; return vd; } void WaitForVar(Variable var) override { - // TODO + // TODO(minjie): tbd } void WaitForAll() override { - // TODO + // TODO(minjie): tbd } + private: enum class DepType { kRead = 0, @@ -103,23 +105,22 @@ class ThreadedEngine : 
public DAGEngine { }; struct VarDescr { spinlock lock; - int rw; // a semaphore-like count - // if rw > 0, the variable has several readers and the number - // means how many operators are currently reading it; - // if rw < 0, the varaible has one writer (should be -1) + int rw; // a semaphore-like count + // if rw > 0, the variable has several readers and the number + // means how many operators are currently reading it; + // if rw < 0, the varaible has one writer (should be -1) queue, DepType>> waitings; }; void TriggerWaiting(VarDescr* vard) { // ATTENTION: this function should be called with vard->lock held. CHECK(vard->rw == 0) << "the variable should be free during triggering"; - if(!vard->waitings.empty()) { + if (!vard->waitings.empty()) { // pop all reads first - while(vard->waitings.front().second == DepType::kRead) { + while (vard->waitings.front().second == DepType::kRead) { vard->waitings.pop(); ++vard->rw; } if (vard->rw == 0) { - // if the next one is a delete // pop the next write vard->waitings.pop(); vard->rw = -1; @@ -128,43 +129,45 @@ class ThreadedEngine : public DAGEngine { } void OnOpFinished(OpDescr* opd) { CHECK(opd) << "completing a nullptr op!"; - for(Variable v : opd->read_vars) { - VarDescr* vard = static_cast(v); // safe to cast here + for (Variable v : opd->read_vars) { + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw; - if(--vard->rw == 0) { + if (--vard->rw == 0) { TriggerWaiting(vard); } spin_unlock(&vard->lock); } - for(Variable v : opd->write_vars) { - VarDescr* vard = static_cast(v); // safe to cast here + for (Variable v : opd->write_vars) { + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw; vard->rw = 0; TriggerWaiting(vard); spin_unlock(&vard->lock); } - delete opd; // delete the operator + delete opd; // delete the operator } RunContext GetRunContext(const Context& ctx) { - // TODO + // TODO(minjie): get the correct runtime context return RunContext(); } void OnDepsResolved(OpDescr* opd) { static default_random_engine generator; - static uniform_int_distribution distribution(0, numthreads_); + static uniform_int_distribution distribution(0, numthreads_ - 1); int thrid = distribution(generator); + // LOG(INFO) << "schedule operator " << opd << " to thread #" << thrid; worker_queues_[thrid]->Push(opd); } void WorkerRoutine(int thrid) { OpDescr* opd = nullptr; - while(! worker_queues_[thrid]->Pop(opd)) { - LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; + while (!worker_queues_[thrid]->Pop(opd)) { + // LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); opd = nullptr; } } + private: const int numthreads_; vector*> worker_queues_; diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc new file mode 100644 index 000000000000..d44f63b52d95 --- /dev/null +++ b/src/symbol/symbol.cc @@ -0,0 +1,141 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file symbol.cc + * \brief symbol of mxnet + */ +#include +#include +#include + +namespace mxnet { + +Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) + : sym_(sym), name_(name) { +} + +Symbol::Node::~Node() { + if (sym_) { + delete sym_; + } +} + +void Symbol::FindArgUsers() { + arg_users_.reset(new std::vector >); + // depth first traversing + std::vector > stk; + stk.push_back({head_.get(), 0}); + while (!stk.empty()) { + std::pair& back = stk.back(); + if (back.first->in_symbol_.size() == back.second) { + stk.pop_back(); + } else { + Node* next_level = back.first->in_symbol_[back.second].get(); + if (next_level->sym_) { + stk.push_back({next_level, 0}); + } else { // back uses next_level which is a placeholder + arg_users_->push_back({back.first, back.second}); + } + back.second += 1; + } + } +} + +Symbol Symbol::Copy() const { + Symbol s; + std::unordered_map > old_new; + std::vector stk; + stk.push_back(head_.get()); + // copy nodes + while (!stk.empty()) { + Node* back = stk.back(); + stk.pop_back(); + if (old_new.count(back) == 0) { + if (back->sym_) { + old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); + } else { + old_new[back] = std::make_shared(nullptr, back->name_); + } + } + for (const std::shared_ptr& n : back->in_symbol_) { + if (old_new.count(n.get()) == 0) { + stk.push_back(n.get()); + } + } + } + // connect nodes + for (auto kv : old_new) { + for (const std::shared_ptr& n : kv.first->in_symbol_) { + kv.second->in_symbol_.push_back(old_new[n.get()]); + } + } + s.head_ = old_new[this->head_.get()]; + // copy arg_users_ + if (arg_users_) { + s.arg_users_.reset(new std::vector >); + std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(*s.arg_users_), + [&old_new](const std::pair& n) -> std::pair { + return { old_new[n.first].get(), n.second }; + }); + } + return s; +} + +Symbol Symbol::operator () (const std::vector& args) const { + Symbol s = this->Copy(); + if (!s.arg_users_) { // if arg_users_ has not been populated + s.FindArgUsers(); + } + CHECK_LT(args.size(), s.arg_users_->size()) << "Too many args, requires " << s.arg_users_->size() + << " provided " << args.size(); + for (size_t i = 0; i < args.size(); ++i) { + const std::pair& arg_user = (*s.arg_users_)[i]; + arg_user.first->in_symbol_[arg_user.second] = args[i].head_; + CHECK_NE(args[i].index_, -1) << "Argument " << i << " is a tuple, scalar is required"; + arg_user.first->in_index_[arg_user.second] = args[i].index_; + } + s.arg_users_.reset(); + return s; +} + +Symbol Symbol::operator () (const std::unordered_map& kwargs) const { + Symbol s = this->Copy(); + if (!s.arg_users_) { // if arg_users_ has not been populated + s.FindArgUsers(); + } + CHECK_LT(kwargs.size(), s.arg_users_->size()) << "Too many args, requires " + << s.arg_users_->size() << " provided " << kwargs.size(); + for (size_t i = 0; i < s.arg_users_->size(); ++i) { + const std::pair& arg_user = (*s.arg_users_)[i]; + const std::string& name = arg_user.first->name_; + if (!(name == "") && kwargs.count(name) != 0) { + const Symbol& bind = kwargs.at(name); + arg_user.first->in_symbol_[arg_user.second] = bind.head_; + CHECK_NE(bind.index_, -1) << "Argument " << name << " is a tuple, scalar is required"; + arg_user.first->in_index_[arg_user.second] = bind.index_; + } + } + s.arg_users_.reset(); + // TODO(linmin): report error if kwargs contains non-existing keys + return s; +} + +Symbol Symbol::operator[] (int index) const { + CHECK_EQ(index_, -1) << "Current 
symbol can't be indexed because it returns a scalar."; + Symbol s = *this; + s.index_ = index; + return s; +} + +std::vector Symbol::ListArgs() { + std::vector ret; + if (!arg_users_) { + FindArgUsers(); + } + std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), + [&](const std::pair& n) -> std::string { + return n.first->in_symbol_[n.second]->name_; + }); + return ret; +} + +} // namespace mxnet diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc index 3e47777e140d..fecd552d1b50 100644 --- a/test/test_threaded_engine.cc +++ b/test/test_threaded_engine.cc @@ -1,10 +1,43 @@ // Copyright (c) 2015 by Contributors -#include +#include +#include +#include + +#include "mxnet/dag_engine.h" using namespace std; using namespace mxnet; +void Foo(RunContext rctx, int i) { + cout << "say: " << i << endl; +} + int main() { DAGEngine* engine = DAGEngine::Get(); + Context exec_ctx; + + // Test #1 + cout << "============= Test #1 ==============" << endl; + vector vars; + for (int i = 0; i < 10; ++i) { + vars.push_back(engine->NewVar()); + } + for (int i = 0; i < 10; ++i) { + engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, + exec_ctx, vars, {}); + } + + usleep(1000000); + + // Test #2 + cout << "============= Test #2 ==============" << endl; + for (int i = 0; i < 10; ++i) { + engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, + exec_ctx, {}, vars); + } + + usleep(1000000); + + // Test #3 return 0; } From 7094ee15ca22fa093ef5cd60347a7ad668aad7ca Mon Sep 17 00:00:00 2001 From: winsty Date: Fri, 17 Jul 2015 15:43:28 +0800 Subject: [PATCH 3/3] rename --- Makefile | 14 +- include/mxnet/operator.h | 49 +++--- .../{narray_operator.h => static_operator.h} | 75 ++++----- src/narray_operator/narray_operator.cc | 127 --------------- src/operator/operator.cc | 149 ++++++++++++++---- .../activation_op-inl.h | 12 +- .../convolution_op-inl.h | 12 +- .../dropout_op-inl.h | 10 +- .../fully_connect_op-inl.h | 12 +- .../mshadow_op.h | 6 +- src/{operator => static_operator}/param.h | 6 +- .../pooling_op-inl.h | 12 +- .../reshape_op-inl.h | 10 +- .../static_operator-inl.h} | 14 +- src/static_operator/static_operator.cc | 44 ++++++ .../static_operator_common.h} | 12 +- .../static_operator_cpu.cc} | 6 +- .../static_operator_gpu.cu} | 6 +- 18 files changed, 288 insertions(+), 288 deletions(-) rename include/mxnet/{narray_operator.h => static_operator.h} (54%) delete mode 100644 src/narray_operator/narray_operator.cc rename src/{operator => static_operator}/activation_op-inl.h (88%) rename src/{operator => static_operator}/convolution_op-inl.h (97%) rename src/{operator => static_operator}/dropout_op-inl.h (92%) rename src/{operator => static_operator}/fully_connect_op-inl.h (93%) rename src/{operator => static_operator}/mshadow_op.h (93%) rename src/{operator => static_operator}/param.h (94%) rename src/{operator => static_operator}/pooling_op-inl.h (95%) rename src/{operator => static_operator}/reshape_op-inl.h (91%) rename src/{operator/operator-inl.h => static_operator/static_operator-inl.h} (76%) create mode 100644 src/static_operator/static_operator.cc rename src/{operator/operator_common.h => static_operator/static_operator_common.h} (84%) rename src/{operator/operator_cpu.cc => static_operator/static_operator_cpu.cc} (72%) rename src/{operator/operator_gpu.cu => static_operator/static_operator_gpu.cu} (74%) diff --git a/Makefile b/Makefile index 460bda8452df..3af4ad9f14d3 100644 --- a/Makefile +++ b/Makefile @@ -56,16 +56,16 @@ endif #BIN = test/test_threaded_engine 
test/api_registry_test BIN = test/api_registry_test -OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o +OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o narray_operator.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o operator_gpu.o + CUOBJ += narray_op_gpu.o static_operator_gpu.o endif .PHONY: clean all test lint doc @@ -81,13 +81,13 @@ engine.o: src/dag_engine/simple_engine.cc narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h -operator.o: src/operator/operator.cc -operator_cpu.o: src/operator/operator_cpu.cc -operator_gpu.o: src/operator/operator_gpu.cu +static_operator.o: src/static_operator/static_operator.cc +static_operator_cpu.o: src/static_operator/static_operator_cpu.cc +static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc -narray_operator.o: src/narray_operator/narray_operator.cc +operator.o: src/operator/operator.cc api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index a2796537f3a9..03e90173fe9e 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -1,7 +1,8 @@ /*! * Copyright (c) 2015 by Contributors * \file operator.h - * \brief static operator interface of mxnet + * \brief operator interface of mxnet + * \author Naiyan Wang */ #ifndef MXNET_OPERATOR_H_ #define MXNET_OPERATOR_H_ @@ -10,6 +11,9 @@ #include #include "./base.h" #include "./tensor_blob.h" +#include "./static_operator.h" +#include "./narray.h" +#include "./dag_engine.h" namespace mxnet { /*! @@ -19,33 +23,28 @@ namespace mxnet { * This interface relies on pre-allocated memory in TBlob, the caller need to set * the memory region in TBlob correctly before calling Forward and Backward * - * \sa TBlob, TShape + * \sa Operator */ class Operator { public: + Operator(StaticOperator* op, Context ctx); /*! * \brief get types of input argument of this oeprator * \return a vector corresponding to type of each argument * this order is same as the order of inputs in Forward, InferShape and Backward */ - virtual std::vector DescribeArgs() const { - // default most of layers only have one data argument - return std::vector(1, kDataArg); - } + virtual std::vector DescribeArgs() const; /*! * \brief describe property of op * \return a bit map in int */ - virtual int DescribeProperty() const { - // default most of layer only conatin internal state - return kContainInteralState; - } + virtual int DescribeProperty() const; /*! * \brief set param for the operator from string * \param name parameter name * \param val string for configuration */ - virtual void SetParam(const char *name, const char *val) {} + virtual void SetParam(const char *name, const char *val); /*! 
* \brief inter the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator @@ -60,7 +59,9 @@ class Operator { * InferShape will modify the vector to fill output TShape */ virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) = 0; + std::vector *out_shape); + + virtual void SetContext(Context ctx); /*! * \brief perform a forward operation of operator, save the output to TBlob * \param opt option on Forward such as whether this is training phase @@ -71,8 +72,8 @@ class Operator { */ virtual void Forward(Option opt, RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) = 0; + const std::vector &in_data, + const std::vector &out_data); /*! * \brief perform a backward operation of the operator to get the gradient * \param ctx runtime context @@ -85,17 +86,15 @@ class Operator { * \sa GradReqType */ virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) = 0; - /*! - * \brief factory unction, create a new operator - * \param type the type of operator - * \param ctx the context device type of operator - * \return a pointer of Operator object - */ - static Operator *Create(const char *type, Context ctx); + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req); + + private: + /* \brief the static operator */ + StaticOperator* op; + Context global_ctx; }; } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/include/mxnet/narray_operator.h b/include/mxnet/static_operator.h similarity index 54% rename from include/mxnet/narray_operator.h rename to include/mxnet/static_operator.h index e55ecc193453..b2f9e49af154 100644 --- a/include/mxnet/narray_operator.h +++ b/include/mxnet/static_operator.h @@ -1,53 +1,54 @@ /*! * Copyright (c) 2015 by Contributors - * \file narray_operator.h - * \brief narray operator interface of mxnet - * \author Naiyan Wang + * \file static_operator.h + * \brief static operator interface of mxnet */ -#ifndef MXNET_NARRAY_OPERATOR_H_ -#define MXNET_NARRAY_OPERATOR_H_ +#ifndef MXNET_STATIC_OPERATOR_H_ +#define MXNET_STATIC_OPERATOR_H_ // this file will be seen by cuda, no c++11 for now #include #include #include "./base.h" #include "./tensor_blob.h" -#include "./operator.h" -#include "./narray.h" -#include "./dag_engine.h" namespace mxnet { /*! - * \brief static operator interface (current interface have not yet todo with scheduler), - * operator is a stateful object that can be used to call forward and backprop + * \brief static StaticOperator interface (current interface have not yet todo with scheduler), + * StaticOperator is a stateful object that can be used to call forward and backprop * * This interface relies on pre-allocated memory in TBlob, the caller need to set * the memory region in TBlob correctly before calling Forward and Backward * - * \sa Operator + * \sa TBlob, TShape */ -class NArrayOperator { +class StaticOperator { public: - NArrayOperator(Operator* op, Context ctx); /*! * \brief get types of input argument of this oeprator * \return a vector corresponding to type of each argument * this order is same as the order of inputs in Forward, InferShape and Backward */ - virtual std::vector DescribeArgs() const; + virtual std::vector DescribeArgs() const { + // default most of layers only have one data argument + return std::vector(1, kDataArg); + } /*! 
* \brief describe property of op * \return a bit map in int */ - virtual int DescribeProperty() const; + virtual int DescribeProperty() const { + // default most of layer only conatin internal state + return kContainInteralState; + } /*! - * \brief set param for the operator from string + * \brief set param for the StaticOperator from string * \param name parameter name * \param val string for configuration */ - virtual void SetParam(const char *name, const char *val); + virtual void SetParam(const char *name, const char *val) {} /*! * \brief inter the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator + * \param in_shape the shape of input arguments of the StaticOperator * this should be of same length as the vector returned by DescribeArgs * in_shape allows unknown elements, which are checked by shape.ndim() == 0. * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape @@ -55,15 +56,13 @@ class NArrayOperator { * * common practice: set the shape of data input, and usually weight's shape can be infered * - * \param out_shape the shape of outputs of the operator + * \param out_shape the shape of outputs of the StaticOperator * InferShape will modify the vector to fill output TShape */ virtual void InferShape(std::vector *in_shape, - std::vector *out_shape); - - virtual void SetContext(Context ctx); + std::vector *out_shape) = 0; /*! - * \brief perform a forward operation of operator, save the output to TBlob + * \brief perform a forward operation of StaticOperator, save the output to TBlob * \param opt option on Forward such as whether this is training phase * \param ctx runtime context * \param in_data array of input data, it is const @@ -72,12 +71,12 @@ class NArrayOperator { */ virtual void Forward(Option opt, RunContext ctx, - const std::vector &in_data, - const std::vector &out_data); + const std::vector &in_data, + const std::vector &out_data) = 0; /*! - * \brief perform a backward operation of the operator to get the gradient + * \brief perform a backward operation of the StaticOperator to get the gradient * \param ctx runtime context - * \param grad_next the gradient value we get from output of the operator + * \param grad_next the gradient value we get from output of the StaticOperator * \param in_data the array of input data * \param out_grad array of output gradient, there could be three possible TBlob * in the each element in the array @@ -86,15 +85,17 @@ class NArrayOperator { * \sa GradReqType */ virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req); - -private: - /* \brief the static operator */ - Operator* op; - Context global_ctx; + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) = 0; + /*! + * \brief factory unction, create a new StaticOperator + * \param type the type of StaticOperator + * \param ctx the context device type of StaticOperator + * \return a pointer of StaticOperator object + */ + static StaticOperator *Create(const char *type, Context ctx); }; } // namespace mxnet -#endif // MXNET_NARRAY_OPERATOR_H_ +#endif // MXNET_STATIC_OPERATOR_H_ diff --git a/src/narray_operator/narray_operator.cc b/src/narray_operator/narray_operator.cc deleted file mode 100644 index 7eaf934e1ec0..000000000000 --- a/src/narray_operator/narray_operator.cc +++ /dev/null @@ -1,127 +0,0 @@ -/*! 
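The StaticOperator interface that closes above is the contract that the concrete layers in this patch (activation, convolution, fully-connected, pooling, ...) implement against pre-allocated TBlobs. As a rough illustration of that contract, the pass-through layer below is a sketch only: IdentityOp, the include paths, and the FlatTo2D helper on TBlob are assumptions based on the rest of the tree, not code from this patch, and it is written CPU-only for brevity.

// Hypothetical example of a StaticOperator implementation (not part of this patch).
// Element types (TShape, TBlob, GradReqType) follow the doc comments above.
#include <vector>
#include <mxnet/base.h>
#include <mxnet/tensor_blob.h>
#include <mxnet/static_operator.h>

namespace mxnet {
namespace op {
class IdentityOp : public StaticOperator {
 public:
  virtual void InferShape(std::vector<TShape> *in_shape,
                          std::vector<TShape> *out_shape) {
    // single input, single output, identical shapes
    out_shape->clear();
    out_shape->push_back((*in_shape)[0]);
  }
  virtual void Forward(Option opt, RunContext ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<TBlob> &out_data) {
    // out_data is pre-allocated by the caller according to InferShape
    mshadow::Tensor<cpu, 2> in = in_data[0].FlatTo2D<cpu, real_t>();
    mshadow::Tensor<cpu, 2> out = out_data[0].FlatTo2D<cpu, real_t>();
    mshadow::Copy(out, in);
  }
  virtual void Backward(RunContext ctx,
                        const std::vector<TBlob> &grad_next,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<GradReqType> &req) {
    // gradient of identity: pass the incoming gradient straight through
    mshadow::Tensor<cpu, 2> g_next = grad_next[0].FlatTo2D<cpu, real_t>();
    mshadow::Tensor<cpu, 2> g_in = out_grad[0].FlatTo2D<cpu, real_t>();
    if (req[0] != kNullOp) mshadow::Copy(g_in, g_next);
  }
};
}  // namespace op
}  // namespace mxnet
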
- * Copyright (c) 2015 by Contributors - * \file narray_operator.cc - * \brief the implementation of narray operator - * \author Naiyan Wang - */ -#include - -namespace mxnet { - - NArrayOperator::NArrayOperator(Operator* op, Context ctx) { - this->op = op; - this->global_ctx = ctx; - } - /*! - * \brief get types of input argument of this oeprator - * \return a vector corresponding to type of each argument - * this order is same as the order of inputs in Forward, InferShape and Backward - */ - std::vector NArrayOperator::DescribeArgs() const { - // default most of layers only have one data argument - return op->DescribeArgs(); - } - /*! - * \brief describe property of op - * \return a bit map in int - */ - int NArrayOperator::DescribeProperty() const { - // default most of layer only conatin internal state - return op->DescribeProperty(); - } - /*! - * \brief set param for the operator from string - * \param name parameter name - * \param val string for configuration - */ - void NArrayOperator::SetParam(const char *name, const char *val) { - op->SetParam(name, val); - } - /*! - * \brief inter the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - */ - void NArrayOperator::InferShape(std::vector *in_shape, - std::vector *out_shape) { - op->InferShape(in_shape, out_shape); - } - /*! - * \brief perform a forward operation of operator, save the output to TBlob - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of TBlob in out_data must be pre-allocated with InferShape - */ - void NArrayOperator::Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - std::vector used_var; - std::vector mutate_var; - std::vector in; - std::vector out; - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].Var()); - in.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_data.size(); ++i) { - mutate_var.push_back(out_data[i].Var()); - out.push_back(out_data[i].data()); - } - DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { - op->Forward(opt, ctx, in, out); - }, global_ctx, used_var, mutate_var); - } - /*! 
- * \brief perform a backward operation of the operator to get the gradient - * \param ctx runtime context - * \param grad_next the gradient value we get from output of the operator - * \param in_data the array of input data - * \param out_grad array of output gradient, there could be three possible TBlob - * in the each element in the array - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType - */ - void NArrayOperator::Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) { - std::vector used_var; - std::vector mutate_var; - std::vector grad_in; - std::vector grad_out; - std::vector data; - for (size_t i = 0; i < grad_next.size(); ++i) { - used_var.push_back(grad_next[i].Var()); - grad_in.push_back(grad_next[i].data()); - } - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].Var()); - data.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_grad.size(); ++i) { - mutate_var.push_back(out_grad[i].Var()); - grad_out.push_back(out_grad[i].data()); - } - DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { - op->Backward(ctx, grad_in, data, grad_out, req); - }, global_ctx, used_var, mutate_var); - } - - void NArrayOperator::SetContext(Context ctx) { - this->global_ctx = ctx; - } - -} // namespace mxnet \ No newline at end of file diff --git a/src/operator/operator.cc b/src/operator/operator.cc index e8ec1c593a5d..ccfc640d3a3d 100644 --- a/src/operator/operator.cc +++ b/src/operator/operator.cc @@ -1,44 +1,127 @@ /*! * Copyright (c) 2015 by Contributors * \file operator.cc - * \brief - * \author: Bing Xu + * \brief the implementation of narray operator + * \author Naiyan Wang */ -#include -#include #include -#include -#include "./operator_common.h" namespace mxnet { -namespace op { -// declare the operator -template -Operator *CreateOperator(OpType type); + Operator::Operator(StaticOperator* op, Context ctx) { + this->op = op; + this->global_ctx = ctx; + } + /*! + * \brief get types of input argument of this oeprator + * \return a vector corresponding to type of each argument + * this order is same as the order of inputs in Forward, InferShape and Backward + */ + std::vector Operator::DescribeArgs() const { + // default most of layers only have one data argument + return op->DescribeArgs(); + } + /*! + * \brief describe property of op + * \return a bit map in int + */ + int Operator::DescribeProperty() const { + // default most of layer only conatin internal state + return op->DescribeProperty(); + } + /*! + * \brief set param for the operator from string + * \param name parameter name + * \param val string for configuration + */ + void Operator::SetParam(const char *name, const char *val) { + op->SetParam(name, val); + } + /*! + * \brief inter the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. 
+ * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + */ + void Operator::InferShape(std::vector *in_shape, + std::vector *out_shape) { + op->InferShape(in_shape, out_shape); + } + /*! + * \brief perform a forward operation of operator, save the output to TBlob + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of TBlob in out_data must be pre-allocated with InferShape + */ + void Operator::Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data) { + std::vector used_var; + std::vector mutate_var; + std::vector in; + std::vector out; + for (size_t i = 0; i < in_data.size(); ++i) { + used_var.push_back(in_data[i].Var()); + in.push_back(in_data[i].data()); + } + for (size_t i = 0; i < out_data.size(); ++i) { + mutate_var.push_back(out_data[i].Var()); + out.push_back(out_data[i].data()); + } + DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { + op->Forward(opt, ctx, in, out); + }, global_ctx, used_var, mutate_var); + } + /*! + * \brief perform a backward operation of the operator to get the gradient + * \param ctx runtime context + * \param grad_next the gradient value we get from output of the operator + * \param in_data the array of input data + * \param out_grad array of output gradient, there could be three possible TBlob + * in the each element in the array + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType + */ + void Operator::Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) { + std::vector used_var; + std::vector mutate_var; + std::vector grad_in; + std::vector grad_out; + std::vector data; + for (size_t i = 0; i < grad_next.size(); ++i) { + used_var.push_back(grad_next[i].Var()); + grad_in.push_back(grad_next[i].data()); + } + for (size_t i = 0; i < in_data.size(); ++i) { + used_var.push_back(in_data[i].Var()); + data.push_back(in_data[i].data()); + } + for (size_t i = 0; i < out_grad.size(); ++i) { + mutate_var.push_back(out_grad[i].Var()); + grad_out.push_back(out_grad[i].data()); + } + DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { + op->Backward(ctx, grad_in, data, grad_out, req); + }, global_ctx, used_var, mutate_var); + } -OpType GetOpType(const char *type) { - if (!strcmp(type, "relu")) return kReLU; - if (!strcmp(type, "fullc")) return kFullc; - LOG(FATAL) << "unknown operator type " << type; - return kReLU; -} -} // namespace op + void Operator::SetContext(Context ctx) { + this->global_ctx = ctx; + } -// implementing the context -Operator *Operator::Create(const char *type, - Context ctx) { - op::OpType otype = op::GetOpType(type); - if (ctx.dev_mask == cpu::kDevMask) { - return op::CreateOperator(otype); - } - if (ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - return op::CreateOperator(otype); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } - return NULL; -} // namespace op } // namespace mxnet diff --git 
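The two Push calls in operator.cc above are the heart of the new wrapper: every input NArray contributes its Var() as a read dependency (used_var), every output contributes its Var() as a mutate dependency (mutate_var), and the actual StaticOperator call is deferred into a lambda that the DAGEngine runs once those dependencies are satisfied. Factored out on its own, the idiom looks roughly like this; PushOp is an illustrative helper, not code from this patch, and it assumes the headers introduced here are on the include path.

// Sketch of the dependency-tracked scheduling used by Operator::Forward/Backward.
#include <functional>
#include <vector>
#include <mxnet/narray.h>
#include <mxnet/dag_engine.h>

namespace mxnet {
inline void PushOp(Context exec_ctx,
                   const std::vector<NArray> &inputs,    // read-only operands
                   const std::vector<NArray> &outputs,   // operands written to
                   std::function<void(RunContext,
                                      const std::vector<TBlob> &,
                                      const std::vector<TBlob> &)> body) {
  std::vector<DAGEngine::Variable> used_var, mutate_var;
  std::vector<TBlob> in, out;
  for (size_t i = 0; i < inputs.size(); ++i) {
    used_var.push_back(inputs[i].Var());    // engine sees these as reads
    in.push_back(inputs[i].data());
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    mutate_var.push_back(outputs[i].Var()); // and these as writes
    out.push_back(outputs[i].data());
  }
  // TBlobs are captured by value, so the closure stays valid after PushOp
  // returns; the engine invokes it once all dependencies are ready.
  DAGEngine::Get()->Push([body, in, out](RunContext rctx) {
    body(rctx, in, out);
  }, exec_ctx, used_var, mutate_var);
}
}  // namespace mxnet

Note, as a design observation, that in the patched Forward the lambda's RunContext parameter shadows the RunContext passed to Operator::Forward; the context that actually matters is the one the engine supplies when it executes the pushed function.
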
a/src/operator/activation_op-inl.h b/src/static_operator/activation_op-inl.h similarity index 88% rename from src/operator/activation_op-inl.h rename to src/static_operator/activation_op-inl.h index 0f46e7a0d994..b1ad0d090706 100644 --- a/src/operator/activation_op-inl.h +++ b/src/static_operator/activation_op-inl.h @@ -4,18 +4,18 @@ * \brief activation operator of mxnet */ -#ifndef MXNET_OPERATOR_ACTIVATION_OP_INL_H_ -#define MXNET_OPERATOR_ACTIVATION_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ #include -#include +#include #include -#include "./operator_common.h" +#include "./static_operator_common.h" namespace mxnet { namespace op { template -class ActivationOp : public Operator { +class ActivationOp : public StaticOperator { public: virtual void InferShape(std::vector *in_shape, std::vector *out_shape) { @@ -57,4 +57,4 @@ class ActivationOp : public Operator { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_ACTIVATION_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ diff --git a/src/operator/convolution_op-inl.h b/src/static_operator/convolution_op-inl.h similarity index 97% rename from src/operator/convolution_op-inl.h rename to src/static_operator/convolution_op-inl.h index 179f74eda8c6..0f7c5ccbb631 100644 --- a/src/operator/convolution_op-inl.h +++ b/src/static_operator/convolution_op-inl.h @@ -4,19 +4,19 @@ * \brief convolution op * \author Bing Xu */ -#ifndef MXNET_OPERATOR_CONVOLUTION_OP_INL_H_ -#define MXNET_OPERATOR_CONVOLUTION_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ -#include +#include #include #include -#include "./operator_common.h" +#include "./static_operator_common.h" #include "./param.h" namespace mxnet { namespace op { template -class ConvolutionOp : public Operator { +class ConvolutionOp : public StaticOperator { public: virtual std::vector DescribeArgs() const { ArgType ret[] = {kDataArg, kWeightArg, kBiasArg}; @@ -266,4 +266,4 @@ class ConvolutionOp : public Operator { }; // class ConvolutionOp } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_CONVOLUTION_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_CONVOLUTION_OP_INL_H_ diff --git a/src/operator/dropout_op-inl.h b/src/static_operator/dropout_op-inl.h similarity index 92% rename from src/operator/dropout_op-inl.h rename to src/static_operator/dropout_op-inl.h index b41266dc0dce..aba19ad3c88b 100644 --- a/src/operator/dropout_op-inl.h +++ b/src/static_operator/dropout_op-inl.h @@ -4,17 +4,17 @@ * \brief dropout operator * \author Bing Xu */ -#ifndef MXNET_OPERATOR_DROPOUT_OP_INL_H_ -#define MXNET_OPERATOR_DROPOUT_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ -#include +#include #include #include "./mshadow_op.h" namespace mxnet { namespace op { template -class DropoutOp : public Operator { +class DropoutOp : public StaticOperator { public: explicit DropoutOp(mshadow::Random *prnd) : prnd_(prnd), mask_used_(false) {} @@ -90,4 +90,4 @@ class DropoutOp : public Operator { }; // class DropoutOp } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_DROPOUT_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_DROPOUT_OP_INL_H_ diff --git a/src/operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h similarity index 93% rename from src/operator/fully_connect_op-inl.h rename to src/static_operator/fully_connect_op-inl.h index 64b9983838dd..de0250f101bd 
100644 --- a/src/operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -4,19 +4,19 @@ * \brief fully connect operator * \author Bing Xu */ -#ifndef MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ -#define MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ #include -#include +#include #include -#include "./operator_common.h" +#include "./static_operator_common.h" #include "./param.h" namespace mxnet { namespace op { template -class FullyConnectOp : public Operator { +class FullyConnectOp : public StaticOperator { public: virtual std::vector DescribeArgs() const { ArgType ret[] = {kDataArg, kWeightArg, kBiasArg}; @@ -109,5 +109,5 @@ class FullyConnectOp : public Operator { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ diff --git a/src/operator/mshadow_op.h b/src/static_operator/mshadow_op.h similarity index 93% rename from src/operator/mshadow_op.h rename to src/static_operator/mshadow_op.h index e0c1ccb7d890..2954b1f81a48 100644 --- a/src/operator/mshadow_op.h +++ b/src/static_operator/mshadow_op.h @@ -4,8 +4,8 @@ * \brief extra mshadow operation for mxnet * \author Bing Xu */ -#ifndef MXNET_OPERATOR_MSHADOW_OP_H_ -#define MXNET_OPERATOR_MSHADOW_OP_H_ +#ifndef MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ +#define MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ #include #include @@ -102,5 +102,5 @@ struct square_root { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_MSHADOW_OP_H_ +#endif // MXNET_STATIC_OPERATOR_MSHADOW_OP_H_ diff --git a/src/operator/param.h b/src/static_operator/param.h similarity index 94% rename from src/operator/param.h rename to src/static_operator/param.h index 0d8016983c5a..c2829aced8ae 100644 --- a/src/operator/param.h +++ b/src/static_operator/param.h @@ -4,8 +4,8 @@ * \brief operator params * \author Bing Xu */ -#ifndef MXNET_OPERATOR_PARAM_H_ -#define MXNET_OPERATOR_PARAM_H_ +#ifndef MXNET_STATIC_OPERATOR_PARAM_H_ +#define MXNET_STATIC_OPERATOR_PARAM_H_ namespace mxnet { namespace op { @@ -68,6 +68,6 @@ struct Param { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_PARAM_H_ +#endif // MXNET_STATIC_OPERATOR_PARAM_H_ diff --git a/src/operator/pooling_op-inl.h b/src/static_operator/pooling_op-inl.h similarity index 95% rename from src/operator/pooling_op-inl.h rename to src/static_operator/pooling_op-inl.h index 2721c5c1930c..e4bf344f7e5a 100644 --- a/src/operator/pooling_op-inl.h +++ b/src/static_operator/pooling_op-inl.h @@ -4,20 +4,20 @@ * \brief pooling operator * \author Bing Xu */ -#ifndef MXNET_OPERATOR_POOLING_OP_INL_H_ -#define MXNET_OPERATOR_POOLING_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ -#include +#include #include #include #include "./param.h" -#include "./operator_common.h" +#include "./static_operator_common.h" namespace mxnet { namespace op { template -class PoolingOp : public Operator { +class PoolingOp : public StaticOperator { public: virtual void SetParam(const char *name, const char *val) { param_.SetParam(name, val); @@ -149,4 +149,4 @@ class PoolingOp : public Operator { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_POOLING_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_POOLING_OP_INL_H_ diff --git a/src/operator/reshape_op-inl.h b/src/static_operator/reshape_op-inl.h similarity index 91% rename from src/operator/reshape_op-inl.h 
rename to src/static_operator/reshape_op-inl.h index a8377ad2b65c..eb05a460573d 100644 --- a/src/operator/reshape_op-inl.h +++ b/src/static_operator/reshape_op-inl.h @@ -4,16 +4,16 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_OPERATOR_RESHAPE_OP_INL_H_ -#define MXNET_OPERATOR_RESHAPE_OP_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ -#include +#include #include namespace mxnet { namespace op { template -class ReshapeOp : public Operator { +class ReshapeOp : public StaticOperator { public: virtual void SetParam(const char *name, const char *val) { if (!strcmp(name, "out_ch")) oshape_[1] = atoi(val); @@ -72,4 +72,4 @@ class ReshapeOp : public Operator { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_RESHAPE_OP_INL_H_ +#endif // MXNET_STATIC_OPERATOR_RESHAPE_OP_INL_H_ diff --git a/src/operator/operator-inl.h b/src/static_operator/static_operator-inl.h similarity index 76% rename from src/operator/operator-inl.h rename to src/static_operator/static_operator-inl.h index dc3b7cced56a..d52f92571412 100644 --- a/src/operator/operator-inl.h +++ b/src/static_operator/static_operator-inl.h @@ -1,14 +1,14 @@ /*! * Copyright (c) 2015 by Contributors - * \file operator-inl.h - * \brief device invarient code to create operators + * \file static_operator-inl.h + * \brief static device invarient code to create operators * \author Bing Xu */ -#ifndef MXNET_OPERATOR_OPERATOR_INL_H_ -#define MXNET_OPERATOR_OPERATOR_INL_H_ +#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ +#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ #include #include -#include +#include #include "./mshadow_op.h" #include "./activation_op-inl.h" #include "./fully_connect_op-inl.h" @@ -25,7 +25,7 @@ namespace op { * \tparam xpu the device type we are at */ template -inline Operator *CreateOperator_(OpType type, mshadow::Random *prnd) { +inline StaticOperator *CreateOperator_(OpType type, mshadow::Random *prnd) { switch (type) { case kReLU: return new ActivationOp(); @@ -49,4 +49,4 @@ inline Operator *CreateOperator_(OpType type, mshadow::Random *prnd) { } } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_OPERATOR_INL_H_ +#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ diff --git a/src/static_operator/static_operator.cc b/src/static_operator/static_operator.cc new file mode 100644 index 000000000000..67464fb394b6 --- /dev/null +++ b/src/static_operator/static_operator.cc @@ -0,0 +1,44 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file static_operator.cc + * \brief + * \author: Bing Xu + */ +#include +#include +#include +#include +#include "./static_operator_common.h" + +namespace mxnet { +namespace op { +// declare the operator +template +StaticOperator *CreateOperator(OpType type); + + +OpType GetOpType(const char *type) { + if (!strcmp(type, "relu")) return kReLU; + if (!strcmp(type, "fullc")) return kFullc; + LOG(FATAL) << "unknown operator type " << type; + return kReLU; +} +} // namespace op + +// implementing the context +StaticOperator *StaticOperator::Create(const char *type, + Context ctx) { + op::OpType otype = op::GetOpType(type); + if (ctx.dev_mask == cpu::kDevMask) { + return op::CreateOperator(otype); + } + if (ctx.dev_mask == gpu::kDevMask) { +#if MXNET_USE_CUDA + return op::CreateOperator(otype); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } + return NULL; +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_common.h b/src/static_operator/static_operator_common.h similarity index 84% rename from src/operator/operator_common.h rename to src/static_operator/static_operator_common.h index f87ffdbb7efb..0d1553703200 100644 --- a/src/operator/operator_common.h +++ b/src/static_operator/static_operator_common.h @@ -1,16 +1,16 @@ /*! * Copyright (c) 2015 by Contributors - * \file operator_common.h + * \file static_operator_common.h * \brief common internal header of most operators * this header includes utility functions operator can use * common type definitions * \author Bing Xu */ -#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_ -#define MXNET_OPERATOR_OPERATOR_COMMON_H_ +#ifndef MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#define MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ #include -#include +#include #include namespace mxnet { namespace op { @@ -67,7 +67,7 @@ enum OpType { * \param type the type of operator */ template -Operator *CreateOperator(OpType type); +StaticOperator *CreateOperator(OpType type); } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ +#endif // MXNET_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ diff --git a/src/operator/operator_cpu.cc b/src/static_operator/static_operator_cpu.cc similarity index 72% rename from src/operator/operator_cpu.cc rename to src/static_operator/static_operator_cpu.cc index b4545266e53b..5b6ea861213b 100644 --- a/src/operator/operator_cpu.cc +++ b/src/static_operator/static_operator_cpu.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2015 by Contributors - * \file operator_cpu.cc + * \file static_operator_cpu.cc * \brief CPU specialization of operator codes * \author Bing Xu */ -#include "./operator-inl.h" +#include "./static_operator-inl.h" namespace mxnet { namespace op { @@ -12,7 +12,7 @@ namespace op { mshadow::Random prnd_cpu(0); template<> -Operator *CreateOperator(OpType type) { +StaticOperator *CreateOperator(OpType type) { return CreateOperator_(type, &prnd_cpu); } diff --git a/src/operator/operator_gpu.cu b/src/static_operator/static_operator_gpu.cu similarity index 74% rename from src/operator/operator_gpu.cu rename to src/static_operator/static_operator_gpu.cu index f745d818b1f0..580fe65d630d 100644 --- a/src/operator/operator_gpu.cu +++ b/src/static_operator/static_operator_gpu.cu @@ -1,12 +1,12 @@ /*! 
* Copyright (c) 2015 by Contributors - * \file operator_gpu.cu + * \file static_operator_gpu.cu * \brief GPU specialization of operator code * \author Bing Xu */ #include #include -#include "operator-inl.h" +#include "static_operator-inl.h" namespace mxnet { namespace op { mshadow::Random prnd_gpu(0); template<> -Operator *CreateOperator<gpu>(OpType type) { +StaticOperator *CreateOperator<gpu>(OpType type) { return CreateOperator_<gpu>(type, &prnd_gpu); }
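Taken together, the renamed pieces compose as follows: StaticOperator::Create dispatches on the device mask to build a concrete kernel-level operator, Operator wraps it for the NArray and DAGEngine world, and execution becomes asynchronous. The snippet below is a usage sketch only; the "fullc" type string comes from GetOpType above, but the parameter name, the empty NArray vectors, and the header paths are assumptions, and it has not been built against this patch.

// Hypothetical usage of the classes introduced or renamed in this patch.
#include <vector>
#include <mxnet/base.h>
#include <mxnet/narray.h>
#include <mxnet/static_operator.h>
#include <mxnet/operator.h>
#include <mxnet/dag_engine.h>

using namespace mxnet;

int main() {
  Context ctx;                                  // CPU by default
  // Device-dispatching factory from static_operator.cc ("relu" or "fullc").
  StaticOperator *fc = StaticOperator::Create("fullc", ctx);
  fc->SetParam("num_hidden", "10");             // parameter name is an assumption
  Operator op(fc, ctx);                         // NArray-level wrapper from operator.cc

  // in_data/out_data must hold properly shaped NArrays before Forward is called;
  // their construction is omitted because it is outside the scope of this patch.
  std::vector<NArray> in_data, out_data;
  Option opt;
  opt.is_train = 1;
  RunContext rctx;                              // unused: the engine supplies its own
  op.Forward(opt, rctx, in_data, out_data);     // returns immediately, work is queued

  // Block until the queued forward pass has produced each output.
  for (size_t i = 0; i < out_data.size(); ++i) {
    DAGEngine::Get()->WaitForVar(out_data[i].Var());
  }
  return 0;
}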