diff --git a/Makefile b/Makefile index a6da0c4206b0..581674c784a2 100644 --- a/Makefile +++ b/Makefile @@ -58,14 +58,14 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o + CUOBJ += narray_op_gpu.o fully_connected_gpu.o endif .PHONY: clean all test lint doc @@ -77,20 +77,16 @@ $(DMLC_CORE)/libdmlc.a: storage.o: src/storage/storage.cc engine.o: src/dag_engine/simple_engine.cc -#engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h -static_operator.o: src/operator/static_operator/static_operator.cc -static_operator_cpu.o: src/operator/static_operator/static_operator_cpu.cc -static_operator_gpu.o: src/operator/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc static_graph.o : src/symbol/static_graph.cc registry.o: src/registry.cc c_api.o: src/c_api.cc operator.o: src/operator/static_operator_wrapper.cc -fully_connect_op_cpu.o: src/operator/static_operator/fully_connect_op.cc -fully_connect_op_gpu.o: src/operator/static_operator/fully_connect_op.cu +fully_connected_cpu.o: src/operator/fully_connected.cc +fully_connected_gpu.o: src/operator/fully_connected.cu lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index bb718b6f9fdb..a9a15c4a8007 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -225,7 +225,7 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **out); /*! - * \brief create Symbol by wrapping AtomicSymbol + * \brief Create an AtomicSymbol. * \param creator the AtomicSymbolCreator * \param num_param the number of parameters * \param keys the keys to the params @@ -233,11 +233,11 @@ MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, - int num_param, - const char **keys, - const char **vals, - SymbolHandle *out); +MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out); /*! * \brief Create a Variable Symbol. * \param name name of the variable diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 1c4098ce7ac8..99829bda92da 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -5,6 +5,7 @@ */ #ifndef MXNET_NARRAY_H_ #define MXNET_NARRAY_H_ + #include #include #include diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index d6b9d865e0c5..60284c1a5fa3 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -1,20 +1,17 @@ /*! * Copyright (c) 2015 by Contributors * \file operator.h - * \brief operator interface of mxnet + * \brief Operator interface of mxnet. 
* \author Naiyan Wang */ #ifndef MXNET_OPERATOR_H_ #define MXNET_OPERATOR_H_ -// this file will be seen by cuda, no c++11 for now + #include #include +#include +#include #include "./base.h" -#if DMLC_USE_CXX11 -#include "./narray.h" -#include "./dag_engine.h" -#endif -#include "./symbolic.h" namespace mxnet { /*! \brief option to pass into the forward function */ @@ -38,21 +35,25 @@ enum OpReqType { /*! \brief add to the provided space */ kAddTo }; + /*! - * \brief StaticOperator interface - * StaticOperator is a stateful object that can be used to call forward and backprop - * + * \brief Operator interface. + * Operator defins basic operation unit of optimized computation graph in mxnet. * This interface relies on pre-allocated memory in TBlob, the caller need to set - * the memory region in TBlob correctly before calling Forward and Backward + * the memory region in TBlob correctly before calling Forward and Backward. + * + * Operator is generated by OperatorProperty. + * To add new operator(aka. layers of neural nets) to mxnet, developer need to create + * a new OperatorProperty and its corresponding Operator. * - * \sa TBlob, TShape + * \sa TBlob, TShape, OperatorProperty */ -class StaticOperator { +class Operator { public: /*! \brief destructor */ - virtual ~StaticOperator() {} + virtual ~Operator() {} /*! - * \brief perform a forward operation of StaticOperator, save the output to TBlob. + * \brief perform a forward operation of Operator, save the output to TBlob. * \param opt option on Forward such as whether this is training phase. * \param ctx runtime context * \param in_data array of input data, it is const @@ -69,7 +70,7 @@ class StaticOperator { /*! * \brief Perform a backward Operation, write gradient to the in_grad. * \param ctx runtime context - * \param out_grad the gradient value we get from output of the StaticOperator + * \param out_grad the gradient value we get from output of the Operator * \param in_data the array of input data. * \param out_data the array of output data. * \param req request types of the saving operation, can be all types. @@ -85,53 +86,198 @@ class StaticOperator { }; #if DMLC_USE_CXX11 - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad) = 0; -}; - -#if DMLC_USE_CXX11 +// OperatorProperty allows C++11, while Operator do not rely on it. /*! - * \brief Operator interface. - * Operator is an object can have Forward and Backward function. + * \brief OperatorProperty is a object that stores all information about Operator. + * It also contains method to generate context(device) specific operators. * - * It can be created from + * It also contains various functions that can be optimally overriden to + * provide optimization chance for computation engine. */ -class Operator { +class OperatorProperty { public: - /*! \brief destructor */ - virtual ~Operator() {} /*! - * \brief Perform a Forward operation of Operator - * After this operation, user can get the result by using function head. + * \brief virtual destructor + */ + virtual ~OperatorProperty() {} + /*! + * \brief Get input arguments of the Operator. + * \return vector of arguments. + */ + virtual std::vector ListArguments() const { + return {"data"}; + } + /*! + * \brief Get name of return values of Operator + * \return name of return values. */ - virtual void Forward() = 0; + virtual std::vector ListReturns() const { + return {"output"}; + } + /*! 
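For orientation, here is a minimal sketch of a concrete property overriding the listing hooks above. ElementwiseAddProp and its two-input signature are illustrative, not part of this patch, and the remaining pure-virtual overrides are omitted:

class ElementwiseAddProp : public OperatorProperty {
 public:
  // two named inputs instead of the default {"data"}
  virtual std::vector<std::string> ListArguments() const {
    return {"lhs", "rhs"};
  }
  virtual std::vector<std::string> ListReturns() const {
    return {"output"};
  }
  // InferShape, Copy, CreateOperator and TypeString must still be provided
};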
\return number of outputs of the Operator */
+  virtual int NumReturns() const {
+    return 1;
+  }
   /*!
-   * \brief Perform a Backward operation of the Operator.
-   * This must be called after Forward.
-   * After this operation, NArrays specified by grad_in_args_store will be updated accordingly.
+   * \brief Set the parameters of the Operator.
+   * \param name parameter name
+   * \param val string for the configuration
    */
-  virtual void Backward() = 0;
-  /*! \return get array of heads in the operator */
-  virtual const std::vector &head() const = 0;
+  virtual void SetParam(const char *name, const char *val) {}
   /*!
-   * \brief Create an operator by bind symbol with context and arguments.
-   * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
+   * \brief infer the shapes of outputs and unknown input arguments
+   * \param in_shape the shape of input arguments of the operator
+   *     this should be of same length as the vector returned by ListArguments
+   *     in_shape allows unknown elements, which are checked by shape.ndim() == 0.
+   *     For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
+   *     For known shapes, InferShape will check shape consistency
    *
-   * \param ctx the context of binding.
-   * \param symbol the symbol that specifies the output of Forward pass.
-   * \param in_args the NArray that stores the input arguments to the symbol.
-   * \param grad_in_args_store NArray that is used to store the gradient output of the input arguments.
-   * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
-   */
-  static Operator *Bind(Symbol symbol,
-                        Context ctx,
-                        const std::vector &in_args,
-                        const std::vector &grad_in_args_store,
-                        const std::vector &grad_req_type);
-};  // class operator
+   * common practice: set the shape of data input, and usually weight's shape can be inferred
+   *
+   * \param out_shape the shape of outputs of the operator
+   *     InferShape will modify the vector to fill output TShape
+   * \return true if the shape inference is successful, false otherwise.
+   */
+  virtual bool InferShape(std::vector<TShape> *in_shape,
+                          std::vector<TShape> *out_shape) const = 0;
+  /*!
+   * \brief Copy this OperatorProperty.
+   * \return a pointer to the copied OperatorProperty
+   */
+  virtual OperatorProperty* Copy() const = 0;
+  /*!
+   * \brief Create an Operator on a specific context
+   */
+  virtual Operator* CreateOperator(Context ctx) const = 0;
+  /*!
+   * \brief return the type string of the Operator
+   *     subclasses override this function.
+   */
+  virtual std::string TypeString() const = 0;
+  /*!
+   * \brief Declare the input requirement of the Backward pass.
+   *
+   * Only the returned list of variables will be used in Backward.
+   * This function is used for memory optimization.
+   * It is advised to override and only return what is actually needed.
+   * If this function is not overridden, all the variables will be valid in Backward.
+   *
+   * \code
+   * // The following code declares that Backward needs out_grad[0], in_data[0], in_data[1]
+   * vector<int> DeclareBackwardDependency(const vector<int> &out_grad,
+   *                                       const vector<int> &in_data,
+   *                                       const vector<int> &out_data) const {
+   *   return {out_grad[0], in_data[0], in_data[1]};
+   * }
+   * \endcode
+   * \param out_grad gradient of outputs in backward pass.
+   * \param in_data the input data in forward pass.
+   * \param out_data the output data in forward pass.
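To make the InferShape contract concrete, a hedged sketch for a fully-connected-style property: the data shape must be known, and the weight shape is filled in from it. num_hidden_ is an assumed member; TShape and mshadow::Shape2 come from the existing headers:

virtual bool InferShape(std::vector<TShape> *in_shape,
                        std::vector<TShape> *out_shape) const {
  const TShape &dshape = (*in_shape)[0];         // data: (batch, num_input)
  if (dshape.ndim() == 0) return false;          // data shape not known yet
  (*in_shape)[1] = mshadow::Shape2(num_hidden_, dshape[1]);  // infer weight
  out_shape->clear();
  out_shape->push_back(mshadow::Shape2(dshape[0], num_hidden_));
  return true;
}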
+   * \return an integer vector indicating the input requirements
+   * \sa BackwardInputs
+   */
+  virtual std::vector<int> DeclareBackwardDependency(
+      const std::vector<int> &out_grad,
+      const std::vector<int> &in_data,
+      const std::vector<int> &out_data) const {
+    // By default, Backward requires all of out_grad, in_data and out_data.
+    // Remember to override this function to get better performance.
+    std::vector<int> ret = out_grad;
+    ret.insert(ret.end(), in_data.begin(), in_data.end());
+    ret.insert(ret.end(), out_data.begin(), out_data.end());
+    return ret;
+  }
+  /*!
+   * \brief Get possible forward inplace options.
+   * This function enables optimization to reuse memory of inputs in output.
+   * Only override when necessary, by default in-place is disabled.
+   *
+   * \code
+   * // The following code says out_data[0] can share data with in_data[0]
+   * vector<pair<int, int> > ForwardInplaceOption(const vector<int> &in_data,
+   *                                              const vector<int> &out_data) const {
+   *   return {{out_data[0], in_data[0]}};
+   * }
+   * \endcode
+   * \return list of pairs of integers taken from the inputs vector,
+   *     indicating possible in place operations.
+   */
+  virtual std::vector<std::pair<int, int> > ForwardInplaceOption(
+      const std::vector<int> &in_data,
+      const std::vector<int> &out_data) const {
+    return std::vector<std::pair<int, int> >();
+  }
+  /*!
+   * \brief Get possible backward inplace options.
+   * This function enables optimization to reuse memory of inputs in output.
+   * Only override when necessary, by default in-place is disabled.
+   *
+   * \code
+   * // The following code says in_grad[0] can share data with in_data[0]
+   * vector<pair<int, int> > BackwardInplaceOption(
+   *     const std::vector<int> &out_grad,
+   *     const std::vector<int> &in_data,
+   *     const std::vector<int> &out_data,
+   *     const std::vector<int> &in_grad) const {
+   *   return {{in_grad[0], in_data[0]}};
+   * }
+   * \endcode
+   * \return list of pairs of integers taken from the inputs vector,
+   *     indicating possible in place operations.
+   */
+  virtual std::vector<std::pair<int, int> > BackwardInplaceOption(
+      const std::vector<int> &out_grad,
+      const std::vector<int> &in_data,
+      const std::vector<int> &out_data,
+      const std::vector<int> &in_grad) const {
+    return std::vector<std::pair<int, int> >();
+  }
+  /*!
+   * \brief Get Backward Input Dependency for generic types of data.
+   * Normally T can be a pointer to Symbol::DataEntry, or an NArray.
+   * This function will select the result list of T according to DeclareBackwardDependency.
+   *
+   * \param in_data the input data in forward pass.
+   * \param out_data the output data in forward pass.
+   * \param out_grad gradient of outputs in backward pass.
+   * \tparam T the generic type parameter.
+   * \return vector of inputs the Backward Operation depends on.
+   * \sa DeclareBackwardDependency
+   */
+  template<typename T>
+  inline std::vector<T> BackwardInputs(const std::vector<T> &in_data,
+                                       const std::vector<T> &out_data,
+                                       const std::vector<T> &out_grad) const {
+    int cnt = 0;
+    std::vector<T> all_vec;
+    std::vector<int> in_data_idx, out_data_idx, out_grad_idx;
+    for (size_t i = 0; i < in_data.size(); ++i) {
+      in_data_idx.push_back(cnt++);
+      all_vec.push_back(in_data[i]);
+    }
+    for (size_t i = 0; i < out_data.size(); ++i) {
+      out_data_idx.push_back(cnt++);
+      all_vec.push_back(out_data[i]);
+    }
+    for (size_t i = 0; i < out_grad.size(); ++i) {
+      out_grad_idx.push_back(cnt++);
+      all_vec.push_back(out_grad[i]);
+    }
+    std::vector<int> ret_idx = this->DeclareBackwardDependency(
+        out_grad_idx, in_data_idx, out_data_idx);
+    std::vector<T> ret;
+    for (size_t i = 0; i < ret_idx.size(); ++i) {
+      ret.push_back(all_vec[ret_idx[i]]);
+    }
+    return ret;
+  }
+  /*!
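As a usage sketch of the declaration above: a property whose backward pass reads only the output gradient and the data input would declare (kOut and kData are assumed enums of the concrete operator, as in fully_connected-inl.h below):

virtual std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const {
  // out_data is not needed, so the engine may free it right after Forward
  return {out_grad[kOut], in_data[kData]};
}

An executor can then call BackwardInputs<NArray>(in_data, out_data, out_grad) to select exactly those arrays.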
+ * \brief create OperatorProperty + * \param type_name the type string of the OperatorProperty + * \return a new constructed OperatorProperty + */ + static OperatorProperty *Create(const char* type_name); +}; #endif } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index 04a3eb1abb51..ddc0a3ca22a0 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -10,9 +10,10 @@ #include #include #include +#include #include "./base.h" #include "./narray.h" -#include "./symbolic.h" +#include "./operator.h" namespace mxnet { @@ -63,9 +64,6 @@ class Registry { } }; -/*! NArrayFunctionEntry requires c++11 */ -#if DMLC_USE_CXX11 -#include /*! \brief mask information on how functions can be exposed */ enum FunctionTypeMask { /*! \brief all the use_vars should go before scalar */ @@ -216,46 +214,45 @@ struct NArrayFunctionEntry { #define REGISTER_NARRAY_FUN(name) \ static auto __ ## name ## _narray_fun__ = \ ::mxnet::Registry::Get()->Register("" # name) -#endif // DMLC_USE_CXX11 -class Symbol; -/*! \brief AtomicSymbolEntry to register */ -struct AtomicSymbolEntry { + +/*! \brief OperatorPropertyEntry to register */ +struct OperatorPropertyEntry { /*! \brief typedef Creator function */ - typedef AtomicSymbol*(*Creator)(); - /*! \brief if AtomicSymbol use param */ + typedef OperatorProperty*(*Creator)(); + /*! \brief if OperatorProperty use param */ bool use_param; /*! \brief name of the entry */ std::string name; - /*! \brief function body to create AtomicSymbol */ + /*! \brief function body to create OperatorProperty */ Creator body; /*! \brief constructor */ - explicit AtomicSymbolEntry(const std::string& name) + explicit OperatorPropertyEntry(const std::string& name) : use_param(true), name(name), body(NULL) {} /*! * \brief set the function body */ - inline AtomicSymbolEntry &set_body(Creator body) { + inline OperatorPropertyEntry &set_body(Creator body) { this->body = body; return *this; } /*! * \brief invoke the function - * \return the created AtomicSymbol + * \return the created OperatorProperty */ - inline AtomicSymbol* operator () () const { + inline OperatorProperty* operator () () const { return body(); } private: /*! \brief disable copy constructor */ - AtomicSymbolEntry(const AtomicSymbolEntry& other) {} + OperatorPropertyEntry(const OperatorPropertyEntry& other) {} /*! \brief disable assignment operator */ - const AtomicSymbolEntry& operator = (const AtomicSymbolEntry& other) { return *this; } + const OperatorPropertyEntry& operator = (const OperatorPropertyEntry& other) { return *this; } }; /*! 
- * \brief macro to register AtomicSymbol to AtomicSymbolFactory + * \brief macro to register OperatorProperty to OperatorPropertyFactory * * Example: the following code is example to register aplus * \code @@ -265,13 +262,13 @@ struct AtomicSymbolEntry { * * \endcode */ -#define REGISTER_ATOMIC_SYMBOL(name, AtomicSymbolType) \ - ::mxnet::AtomicSymbol* __make_ ## AtomicSymbolType ## __() { \ - return new AtomicSymbolType; \ +#define REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ + ::mxnet::OperatorProperty* __make_ ## OperatorPropertyType ## __() { \ + return new OperatorPropertyType; \ } \ - static ::mxnet::AtomicSymbolEntry& __ ## name ## _atomic_symbol__ = \ - ::mxnet::Registry< ::mxnet::AtomicSymbolEntry >::Get()->Register("" # name) \ - .set_body(__make_ ## AtomicSymbolType ## __) + static ::mxnet::OperatorPropertyEntry& __ ## name ## _atomic_symbol__ = \ + ::mxnet::Registry< ::mxnet::OperatorPropertyEntry >::Get()->Register("" # name) \ + .set_body(__make_ ## OperatorPropertyType ## __) } // namespace mxnet #endif // MXNET_REGISTRY_H_ diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index c5a92fb07e35..dc00f5a33fb6 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -1,216 +1,29 @@ /*! * Copyright (c) 2015 by Contributors * \file symbolic.h - * \brief - * \author Bing Xu + * \brief Symbolic interface of mxnet. + * \author Min Lin, Bing Xu */ - #ifndef MXNET_SYMBOLIC_H_ #define MXNET_SYMBOLIC_H_ +#include #include #include #include -#include #include -#if DMLC_USE_CXX11 #include #include -#endif #include "./base.h" +#include "./narray.h" +#include "./operator.h" -namespace mxnet { -// forward declare StaticOperator -class StaticOperator; -#if DMLC_USE_CXX11 -/*! - * \brief AtomicSymbol is the base class of all atomic symbols. - * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance - * of AtomicSymbol can be shared in the graphs of different Symbols - */ -class AtomicSymbol { - public: - /*! - * \brief virtual destructor - */ - virtual ~AtomicSymbol() {} - /*! \brief get the descriptions of inputs for this symbol */ - virtual std::vector ListArguments() const { - // default implementation returns "data" - return std::vector(1, std::string("data")); - } - /*! \brief get the descriptions of outputs for this symbol */ - virtual std::vector ListReturns() const { - // default implementation returns "output" - return std::vector(1, std::string("output")); - } - /*! \brief number of outputs of the symbol */ - virtual int NumReturns() const { - return 1; - } - /*! - * \brief set param for the symbol from string - * \param name parameter name - * \param val string for the configuration - */ - virtual void SetParam(const char *name, const char *val) {} - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. 
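For reference, registering a property and then creating it by type name looks roughly like this; MyOp and MyOpProp are illustrative, and the pattern mirrors the FullyConnected registration later in this patch:

// some_op.cc
REGISTER_OP_PROPERTY(MyOp, MyOpProp);  // MyOpProp : public OperatorProperty

// elsewhere
OperatorProperty *prop = OperatorProperty::Create("MyOp");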
- */ - virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ - virtual AtomicSymbol* Copy() const = 0; - /*! - * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. - * Calling bind from the Symbol wrapper would generate a NArrayOperator. - */ - StaticOperator* Bind(Context ctx) const; - /*! - * \brief return the type string of the atomic symbol - * subclasses override this function. - */ - virtual std::string TypeString() const = 0; - /*! - * \brief Declare the input requirement of Backward pass. - * - * Only the returned list of variables will be used in Backward. - * This function is used for memory optimization. - * It is adviced to override and only return what is actually needed. - * If this function is not overriden, all the variables will be valid in Backward. - * - * \code - * // The following code declares Backward need out_grad[0], in_data[0],in_data[1] - * vector BackwardInputs(const vector &out_grad, - * const vector &in_data, - * const vector &out_data) const { - * return {out_grad[0], in_data[0], in_data[1]}; - * } - * \endcode - * \param out_grad gradient of outputs in backward pass. - * \param in_data the input data in forward pass. - * \param out_data the output data in forward pass. - * \return an integer vector indicating the input requirments - * \sa BackwardInputs - */ - virtual std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const { - // By default requires to see all the things. - // remember to override this function to get a better performance. - std::vector ret = out_grad; - ret.insert(ret.end(), in_data.begin(), in_data.end()); - ret.insert(ret.end(), out_data.begin(), out_data.end()); - return ret; - } - /*! - * \brief Get possible forward inplace options. - * This function enables optimization to reuse memory of inputs in output. - * Only override when necessary, by default in-place is disabled. - * - * \code - * // The following code says out_data[0] can share data with in_data[0] - * vector > ForwardInplaceOption(const vector &in_data, - * const vector &out_data) const { - * return {{out_data[0], in_data[0]}}; - * } - * \endcode - * \return list of pair of integers taken from the inputs vector, - * indicating possible in place operations. - */ - virtual std::vector > ForwardInplaceOption( - const std::vector &in_data, - const std::vector &out_data) const { - return std::vector >(); - } - /*! - * \brief Get possible backward inplace options. - * This function enables optimization to reuse memory of inputs in output. - * Only override when necessary, by default in-place is disabled. - * - * \code - * // The following code says in_grad[0] can share data with in_data[0] - * vector > BackwardInplaceOption( - * const std::vector &out_grad, - * const std::vector &in_data, - * const std::vector &out_data, - * const std::vector &in_grad) const { - * return {in_grad[0], in_data[0]}}; - * } - * \endcode - * \return list of pair of integers taken from the inputs vector, - * indicating possible in place operations. 
- */ - virtual std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const { - return std::vector >(); - } - /*! - * \brief Get Backward Input Dependency for generic types of data. - * Normally T can be pointer of Symbol::DataEntry, or NArray. - * This function will select the result list of T according to DeclareBackwardDependency. - * - * \param in_data the input data in forward pass. - * \param out_data the output data in forward pass. - * \param out_grad gradient of outputs in backward pass. - * \tparam T the generic type parameter. - * \return vector of inputs the Backward Operation depends on. - * \sa DeclareBackwardDependency - */ - template - inline std::vector BackwardInputs(const std::vector &in_data, - const std::vector &out_data, - const std::vector &out_grad) const { - int cnt = 0; - std::vector all_vec; - std::vector in_data_idx, out_data_idx, out_grad_idx; - for (size_t i = 0; i < in_data.size(); ++i) { - in_data_idx.push_back(cnt++); - all_vec.push_back(in_data[i]); - } - for (size_t i = 0; i < out_data.size(); ++i) { - out_data_idx.push_back(cnt++); - all_vec.push_back(out_data[i]); - } - for (size_t i = 0; i < out_grad.size(); ++i) { - out_grad_idx.push_back(cnt++); - all_vec.push_back(out_data[i]); - } - std::vector ret_idx = this->DeclareBackwardDependency( - in_data_idx, out_data_idx, out_grad_idx); - std::vector ret; - for (size_t i = 0; i < ret_idx.size(); ++i) { - ret.push_back(all_vec[ret_idx[i]]); - } - return ret; - } - /*! - * \brief create atomic symbol by type name - * \param type_name the type string of the AtomicSymbol - * \return a new constructed AtomicSymbol - */ - static AtomicSymbol *Create(const char* type_name); -}; +// check c++11 +#if DMLC_USE_CXX11 == 0 +#error "CXX11 was required for symbolic module" +#endif +namespace mxnet { /*! * \brief StaticGraph is the configuration of computation graphs. * This is the "configuration file" of mxnet. @@ -222,16 +35,13 @@ class StaticGraph { struct DataEntry { /*! \brief the source node id in the computation graph */ uint32_t source_id; - /*! - * \brief index of output from the source. - * If index == -1, it represents all the outputs. - */ - int32_t index; + /*! \brief index of output from the source. */ + uint32_t index; }; /*! \brief Operation Node in static graph */ struct Node { - /*! \brief wrapped atomic symbol */ - std::unique_ptr sym; + /*! \brief wrapped operator property */ + std::unique_ptr op; /*! \brief name of the node */ std::string name; /*! \brief inputs (node_id, index) for of the nodes*/ @@ -278,6 +88,7 @@ class StaticGraph { bool InferShape(std::vector *in_shape, std::vector *out_shape) const; }; + /*! * \brief Symbol is used to represent dynamically generated symbolic computation graph. * @@ -352,24 +163,15 @@ class Symbol { * \param name name of returned symbol. * \return a new Symbol which is the composition of current symbol with its arguments */ - inline Symbol operator () (const std::vector& args, - const std::string& name) const { - Symbol s = this->Copy(); - s.Compose(args, name); - return s; - } + Symbol operator () (const std::vector& args, const std::string& name) const; /*! * \brief compose with named arguments * \param kwargs keyword arguments for the symbol * \param name name of returned symbol. 
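A hedged sketch of positional composition with the operator() overloads of Symbol; all names are illustrative:

Symbol data = Symbol::CreateVariable("data");
Symbol fc = Symbol::Create(OperatorProperty::Create("FullyConnected"));
Symbol fc1 = fc({data}, "fc1");  // copies fc and composes it with data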
* \return a new symbol which is the composition of current symbol with its arguments */ - inline Symbol operator () (const std::unordered_map& kwargs, - const std::string& name) const { - Symbol s = this->Copy(); - s.Compose(kwargs, name); - return s; - } + Symbol operator () (const std::unordered_map& kwargs, + const std::string& name) const; /*! * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator @@ -384,12 +186,7 @@ class Symbol { * InferShape will modify the vector to fill output TShape * \return if the shape inference is successful, return true, else return false. */ - inline bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { - StaticGraph g; - this->ToStaticGraph(&g); - return g.InferShape(in_shape, out_shape); - } + bool InferShape(std::vector *in_shape, std::vector *out_shape) const; /*! * \brief get number of outputs of this symbol * \return number of outputs @@ -398,54 +195,41 @@ class Symbol { return heads_.size(); } /*! - * \brief create Symbol by wrapping AtomicSymbol - * This function takes the ownership of atomic_symbol. + * \brief create Symbol by wrapping OperatorProperty + * This function takes the ownership of op * - * \param atomic_symbol the AtomicSymbol + * \param op the OperatorProperty of the Operator * \return Symbol - * \sa AtomicSymbol::Create + * \sa OperatorProperty::Create */ - static Symbol Create(AtomicSymbol *atomic_symbol); + static Symbol Create(OperatorProperty *op); /*! * \brief create equivalence of symbol from static graphs * \param graph the static graph * \return the created symbol */ static Symbol Create(const StaticGraph &graph); - /*! * \brief create equivalence of symbol by grouping the symbols together * \param symbols list of symbols * \return the grouped symbol */ - static Symbol CreateGroup(const std::vector &symbols) { - Symbol ret; - for (const auto &s : symbols) { - ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end()); - } - return std::move(ret); - } + static Symbol CreateGroup(const std::vector &symbols); /*! * \brief create variable symbol node * \param name name of the variable * \return the new variable */ - inline static Symbol CreateVariable(const std::string &name) { - Symbol s; - s.heads_.push_back(DataEntry(std::make_shared(nullptr, name), 0)); - return std::move(s); - } + static Symbol CreateVariable(const std::string &name); protected: - // forward declare Node + // Decalre node, internal data structure. struct Node; /*! \brief an entry that represents output data from a node */ struct DataEntry { /*! \brief the source node of this data */ std::shared_ptr source; - /*! - * \brief index of output from the source. - */ + /*! \brief index of output from the source. */ uint32_t index; /*! \brief enabled default copy constructor */ DataEntry() {} @@ -453,48 +237,15 @@ class Symbol { DataEntry(std::shared_ptr source, uint32_t index) : source(source), index(index) {} }; - /*! - * \brief Node is represents node of an operator in the symbolic graph. - * - * It stores connection to the inputs to function represented by AtomicSymbol - * NOTE on data structure: there are three types of node: - * - Normal node: contains all the necessary elements of a graph. - * - AtomicSymbol: the inputs_ is empty, represents an AtomicSymbol that has not been applied. - * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. - */ - struct Node { - /*! 
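Keyword composition and grouping, under the same illustrative names; composing the unapplied atomic symbol a second time relies on operator() copying it first:

Symbol fc2 = fc({{"data", data}}, "fc2");        // kwargs overload
Symbol group = Symbol::CreateGroup({fc1, fc2});  // both heads in one symbol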
\brief wrapped atomic symbol */ - std::unique_ptr sym; - /*! \brief name of the node */ - std::string name; - /*! \brief inputs to this node */ - std::vector inputs; - /*! - * \brief constructor - * \param sym the AtomicSymbol to construct the symbol - * \param name the name of the symbol - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") - : sym(sym), name(name) { - } - /*! \return Whether the symbol is AtomicSymbol */ - inline bool is_atomic() const { - return inputs.size() == 0 && sym != nullptr; - } - /*! \return Whetehr the symbolc is a PlaceHolder */ - inline bool is_variable() const { - return sym == nullptr; - } - }; /*! * \brief the head nodes of Symbols * This head is only effective when */ std::vector heads_; - /*! \return whwther the symbol is AtomicSymbol */ - inline bool is_atomic() const { - return heads_.size() == 1 && heads_[0].source->is_atomic(); - } + + private: + /*! \return whwther the symbol is atomic */ + inline bool is_atomic() const; /*! * \brief Visit all the nodes in left-to-right depth first order. * @@ -505,30 +256,7 @@ class Symbol { * \tparam FVisit visiting function type */ template - inline void DFSVisit(FVisit fvisit) const { - std::vector stack; - std::unordered_set visited; - // put the head into the graph - for (auto &head : heads_) { - Node *ptr = head.source.get(); - if (visited.count(ptr) == 0) { - stack.push_back(ptr); - visited.insert(ptr); - } - } - while (!stack.empty()) { - Node* back = stack.back(); - stack.pop_back(); - fvisit(back); - for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) { - Node *ptr = it->source.get(); - if (visited.count(ptr) == 0) { - stack.push_back(ptr); - visited.insert(ptr); - } - } - } - } + inline void DFSVisit(FVisit fvisit) const; /*! * \brief Find duplicate arguments in the composition * \param out the map of argument-name -> occurence count @@ -536,6 +264,48 @@ class Symbol { */ int FindDuplicateArgs(std::unordered_map *out) const; }; -#endif // DMLC_USE_CXX11 + +/*! + * \brief Executor of a computation graph. + * Executor can be created by Binding a symbol. + */ +class Executor { + public: + /*! \brief destructor */ + virtual ~Executor() {} + /*! + * \brief Perform a Forward operation of Operator + * After this operation, user can get the result by using function head. + */ + virtual void Forward() = 0; + /*! + * \brief Perform a Backward operation of the Operator. + * This must be called after Forward. + * After this operation, NArrays specified by grad_in_args_store will be updated accordingly. + * \param head_grads the gradient of head nodes to be backproped. + */ + virtual void Backward(const std::vector &head_grads) = 0; + /*! + * \brief get array of heads in the executor. + * \return array of heads in the executor. + */ + virtual const std::vector &heads() const = 0; + /*! + * \brief Create an operator by bind symbol with context and arguments. + * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. + * + * \param ctx the context of binding. + * \param symbol the symbol that specifies the output of Forward pass. + * \param in_args the NArray that stores the input arguments to the symbol. + * \param arg_grad_store NArray that is used to store the gradient output of the input arguments. + * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. + * \return a new executor. 
+ */ + static Executor *Bind(Symbol symbol, + Context ctx, + const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type); +}; // class operator } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ diff --git a/make/config.mk b/make/config.mk index dccb959c2f36..48587a4f9114 100644 --- a/make/config.mk +++ b/make/config.mk @@ -49,7 +49,7 @@ PS_PATH = NONE PS_THIRD_PATH = NONE # whether compile with rabit -USE_RABIT_PS = 1 +USE_RABIT_PS = 0 RABIT_PATH = rabit # use openmp iterator diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py index d4b87e401e3b..c81deebaef11 100644 --- a/python/mxnet/symbol_creator.py +++ b/python/mxnet/symbol_creator.py @@ -61,10 +61,10 @@ def __call__(self, *args, **kwargs): param_keys = c_array(ctypes.c_char_p, param_keys) param_vals = c_array(ctypes.c_char_p, param_vals) sym_handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromAtomicSymbol( \ - self.handle, len(param_keys), \ - param_keys, param_vals, \ - ctypes.byref(sym_handle))) + check_call(_LIB.MXSymbolCreateAtomicSymbol( + self.handle, len(param_keys), + param_keys, param_vals, + ctypes.byref(sym_handle))) if len(args) != 0 and len(symbol_kwargs) != 0: raise TypeError('%s can only accept input \ diff --git a/src/c_api.cc b/src/c_api.cc index 9620840cf3b0..d5a1a67d70c6 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -275,7 +276,7 @@ int MXFuncInvoke(FunctionHandle fun, int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, AtomicSymbolCreator **out_array) { API_BEGIN(); - auto &vec = Registry::List(); + auto &vec = Registry::List(); *out_size = static_cast(vec.size()); *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); @@ -284,28 +285,28 @@ int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **out) { API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); + OperatorPropertyEntry *e = static_cast(creator); *out = e->name.c_str(); API_END(); } -int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, - int num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { +int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out) { MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); - AtomicSymbol *atomic_symbol = nullptr; + OperatorProperty *op = nullptr; API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - atomic_symbol = (*e)(); + OperatorPropertyEntry *e = static_cast(creator); + op = (*e)(); for (int i = 0; i < num_param; ++i) { - atomic_symbol->SetParam(keys[i], vals[i]); + op->SetParam(keys[i], vals[i]); } - s->sym = Symbol::Create(atomic_symbol); + s->sym = Symbol::Create(op); *out = s; - API_END_HANDLE_ERROR(delete s; delete atomic_symbol); + API_END_HANDLE_ERROR(delete s; delete op); } int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { diff --git a/src/operator/composite_operator.cc b/src/operator/composite_operator.cc deleted file mode 100644 index 1853c0539000..000000000000 --- a/src/operator/composite_operator.cc +++ /dev/null @@ -1,127 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file composite_operator.cc - * \brief composite operator of mxnet - * \author Bing Xu -*/ -#include -#include -#include -#include -#include -#include - -namespace mxnet { -/*! 
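A usage sketch of the new Executor interface; filling of the NArray buffers is elided, and Context(), the empty head gradient, and net are illustrative assumptions:

std::vector<NArray> in_args;    // one preallocated array per argument of net
std::vector<NArray> arg_grads;  // same length, buffers for input gradients
std::vector<OpReqType> req(in_args.size(), kWriteTo);
Executor *exec = Executor::Bind(net, Context(), in_args, arg_grads, req);
exec->Forward();
exec->Backward({});             // empty head_grads assumes net ends in a loss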
- * \brief composite_operator interface - * composite operator is a combination of static operator from static graph - */ -class CompositeOperator : public Operator { - public: - /*! \brief destructor */ - virtual ~CompositeOperator() {} - /*! - * \brief describe property of op - * \return a bit map in int - */ - virtual int DescribeProperty() const { - // default most of layer only conatin internal state - return kContainInteralState; - } - /*! \brief Make operator by using graph - * \param ctx ctx context of the created operator - * \param in input narray - * \param grad gradient narray - * \param req gradient request - */ - void Bind(Context ctx, - const std::vector &in, - const std::vector &grad - const std::vector &req) { - ctx_ = ctx; - // infer shape - // build dict - // alloc nodes - // alloc feature map - UpdateConnection(in, grad, req); - } - /*! - * \brief Update connections data in/after bind - * \param in input narray - * \param grad gradient narray - * \param req gradient request - */ - void UpdateConnection(const std::vector &in, - const std::vector &grad, - const std::vector &req) { - CHECK_EQ(in.size() == nodes_.size()); - CHECK_EQ(grad.size() == nodes_.size()); - CHECK_EQ(req.size() == nodes_.size()); - } - /*! - * \brief perform a forward operation of operator (no change to binded NArray) - * \param opt option on Forward such as whether this is training phase - */ - virtual void Forward(Option opt) { - for (auto nid : topo_order_) { - if (nodes_[nid].is_variable) continue; - nodes_[nid].op->Forward(opt, - ctx_, - nodes_[nid].inputs, - nodes_[nid].outputs); - } - } - /*! - * \brief perform a backward operation of the operator to get the gradient - * No change to Binded NArray - */ - virtual void Backward() { - for (auto it = topo_order_.rbegin(); it < topo_order_.rend(); ++it) { - if (nodes_[*it].is_variable) continue; - nodes_[*it].op->Backward(ctx_, - nodes_[*it].outputs, - nodes_[*it].inputs, - nodes_[*it].outputs_grad, - nodes_[*it].req); - } - } - /*! - * \brief perform an extraction operation to get outputs - * \param name of symbol need to be extracted - * \return empty narray for invalid name or narray of the feature map - */ - virtual std::vector Extract(const std::string &symbol_name) { - auto it = name_dict_.find(symbol_name); - if (it == name_dict_.end()) return {}; - return nodes_[it->second].outputs; - } - private: - /*! - * \brief Structure for OpNode - */ - struct OpNode { - /*! \brief Static Operator */ - std::unique_ptr op; - /*! \brief inputs (init after setting output correctly) */ - std::vector inputs; - /*! \brief outputs */ - std::vector outputs; - /*! \brief gradient for output */ - std::vector outputs_grad; - /*! \brief gradient req for grad */ - std::vector req; - /*! \brief is variable */ - bool is_variable; - }; - /*! \brief connections */ - std::vector nodes_; - /*! \brief topo order of connections */ - std::vector topo_order_; - /*! \brief static graph */ - StaticGraph graph_; - /*! \brief running context */ - RunContext ctx_; - /*! 
\brief name id dictionary */ - std::unordered_map name_dict_; -}; // class CompositeOperator -} // namespace mxnet diff --git a/src/operator/static_operator/fully_connect_op-inl.h b/src/operator/fully_connected-inl.h similarity index 82% rename from src/operator/static_operator/fully_connect_op-inl.h rename to src/operator/fully_connected-inl.h index 9bdca812fa20..92a95a2ada2c 100644 --- a/src/operator/static_operator/fully_connect_op-inl.h +++ b/src/operator/fully_connected-inl.h @@ -3,37 +3,33 @@ * \file fully_connect_op-inl.h * \brief fully connect operator and symbol */ -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#ifndef MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ +#define MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ #include #include -#include #include #include #include -#include "./static_operator_common.h" +#include "./operator_common.h" #include "./param.h" namespace mxnet { namespace op { + // Declare enumeration of input order to make code more intuitive. // These enums are only visible within this header -enum FullyConnectOpInputs {kData, kWeight, kBias}; -enum FullyConnectOpOutputs {kOut}; +enum FullyConnectedOpInputs {kData, kWeight, kBias}; +enum FullyConnectedOpOutputs {kOut}; /** - * \brief This is the implementation of fully connected layer. - * + * \brief This is the implementation of fully connected operator. * \tparam xpu The device that the op will be executed on. */ template -class FullyConnectOp : public StaticOperator { +class FullyConnectedOp : public Operator { public: - /*! - * \brief constructor with parameters. Used in Bind() in corresponding symbol. - */ - explicit FullyConnectOp(Param p) { + explicit FullyConnectedOp(Param p) { this->param_ = p; } @@ -97,17 +93,14 @@ class FullyConnectOp : public StaticOperator { private: /** The param of the fully connected layer.*/ Param param_; -}; // class FullyConnectOp +}; // class FullyConnectedOp -// Decalre factory function, used for dispatch specialization +// Decalre Factory function, used for dispatch specialization template -StaticOperator* CreateFullyConnectedOp(Param param); +Operator* CreateFullyConnectedOp(Param param); #if DMLC_USE_CXX11 -/** - * @brief The symbol part of the fully connected layer. 
- */
-class FullyConnectSymbol : public AtomicSymbol {
+class FullyConnectedProp : public OperatorProperty {
  public:
   virtual std::vector<std::string> ListArguments() const {
     if (param_.no_bias == 0) {
@@ -144,14 +137,14 @@ class FullyConnectSymbol : public AtomicSymbol {
     return true;
   }
 
-  virtual AtomicSymbol* Copy() const {
-    FullyConnectSymbol* fc_sym = new FullyConnectSymbol();
+  virtual OperatorProperty* Copy() const {
+    FullyConnectedProp* fc_sym = new FullyConnectedProp();
     fc_sym->param_ = this->param_;
     return fc_sym;
   }
 
   virtual std::string TypeString() const {
     return "FullyConnected";
   }
-  // decalre dependency and inplace optimization options
+  // declare dependency and inplace optimization options
   virtual std::vector<int> DeclareBackwardDependency(
@@ -169,15 +162,12 @@ class FullyConnectSymbol : public AtomicSymbol {
     return {{in_grad[kData], in_data[kData]}};
   }
 
-  // bind function
-  StaticOperator* Bind(Context ctx) const;
+  Operator* CreateOperator(Context ctx) const;
 
  private:
-  /** The param of the fully connected layer.*/
   Param param_;
-};  // class FullyConnectSymbol
+};  // class FullyConnectedProp
 #endif
 }  // namespace op
 }  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_
+#endif  // MXNET_OPERATOR_FULLY_CONNECTED_INL_H_
diff --git a/src/operator/fully_connected.cc b/src/operator/fully_connected.cc
new file mode 100644
index 000000000000..362d3c5698aa
--- /dev/null
+++ b/src/operator/fully_connected.cc
@@ -0,0 +1,22 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file fully_connected.cc
+ * \brief fully connect operator
+*/
+#include
+#include "./fully_connected-inl.h"
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateFullyConnectedOp<cpu>(Param param) {
+  return new FullyConnectedOp<cpu>(param);
+}
+
+// DO_BIND_DISPATCH comes from operator_common.h
+Operator* FullyConnectedProp::CreateOperator(Context ctx) const {
+  DO_BIND_DISPATCH(CreateFullyConnectedOp, param_);
+}
+
+REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp);
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/fully_connected.cu b/src/operator/fully_connected.cu
new file mode 100644
index 000000000000..223ef5166cc9
--- /dev/null
+++ b/src/operator/fully_connected.cu
@@ -0,0 +1,14 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file fully_connected.cu
+ * \brief fully connect operator
+*/
+#include "./fully_connected-inl.h"
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateFullyConnectedOp<gpu>(Param param) {
+  return new FullyConnectedOp<gpu>(param);
+}
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/static_operator/static_operator_common.h b/src/operator/operator_common.h
similarity index 88%
rename from src/operator/static_operator/static_operator_common.h
rename to src/operator/operator_common.h
index f90b9ffd6ce3..87b581f28278 100644
--- a/src/operator/static_operator/static_operator_common.h
+++ b/src/operator/operator_common.h
@@ -1,17 +1,17 @@
 /*!
  * Copyright (c) 2015 by Contributors
- * \file static_operator_common.h
+ * \file operator_common.h
  * \brief common internal header of most operators
  *  this header includes utility functions operator can use
- *  common type definitions
  * \author Bing Xu
 */
-#ifndef MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_
-#define MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_
+#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_
+#define MXNET_OPERATOR_OPERATOR_COMMON_H_
 #include
 #include
 #include
+
 namespace mxnet {
 namespace op {
 /*!
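Tying the pieces together, the intended flow from registry lookup to a device-specific operator; ctx is an illustrative Context and "nhidden" follows param.h below:

OperatorProperty *prop = OperatorProperty::Create("FullyConnected");
prop->SetParam("nhidden", "100");          // forwarded to Param::SetParam
Operator *op = prop->CreateOperator(ctx);  // dispatches via DO_BIND_DISPATCH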
@@ -49,7 +49,7 @@ inline void ShapeAssignCheck(TShape &out, const TS &shape) { // NOLINT(*) } } -// definition of micro +// helper macro to implement bind dispatch #if MXNET_USE_CUDA #define DO_BIND_DISPATCH(Method, ...) \ if (ctx.dev_mask == cpu::kDevMask) { \ @@ -69,4 +69,4 @@ inline void ShapeAssignCheck(TShape &out, const TS &shape) { // NOLINT(*) } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_ +#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/static_operator/param.h b/src/operator/param.h similarity index 90% rename from src/operator/static_operator/param.h rename to src/operator/param.h index f6e91293eca3..e1f6b4ee58d8 100644 --- a/src/operator/static_operator/param.h +++ b/src/operator/param.h @@ -1,11 +1,13 @@ /*! * Copyright (c) 2015 by Contributors * \file param.h - * \brief operator params + * \brief Common operator parameters * \author Bing Xu */ -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ +#ifndef MXNET_OPERATOR_PARAM_H_ +#define MXNET_OPERATOR_PARAM_H_ + +#include namespace mxnet { namespace op { @@ -39,6 +41,12 @@ struct Param { int num_input_node; /*! \brief reserved fields, for future compatibility */ int reserved[64]; + + // constructor + Param() { + memset(this, 0, sizeof(Param)); + } + inline void SetParam(const char *name, const char* val) { if (!strcmp(name, "nhidden")) num_hidden = atoi(val); if (!strcmp(name, "num_input_node")) num_input_node = atoi(val); @@ -68,6 +76,5 @@ struct Param { } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_STATIC_OPERATOR_PARAM_H_ - +#endif // MXNET_OPERATOR_PARAM_H_ diff --git a/src/operator/static_operator/fully_connect_op.cc b/src/operator/static_operator/fully_connect_op.cc deleted file mode 100644 index 69687024384e..000000000000 --- a/src/operator/static_operator/fully_connect_op.cc +++ /dev/null @@ -1,22 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file fully_connect_sym.cc - * \brief fully connect operator symbol -*/ -#include -#include "../static_operator/fully_connect_op-inl.h" -namespace mxnet { -namespace op { -template<> -StaticOperator* CreateFullyConnectedOp(Param param) { - return new FullyConnectOp(param); -} - -// DO_BIND_DISPATCH comes from static_operator_common.h -StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - DO_BIND_DISPATCH(CreateFullyConnectedOp, param_); -} - -REGISTER_ATOMIC_SYMBOL(FullyConnected, FullyConnectSymbol); -} // namespace op -} // namespace mxnet diff --git a/src/operator/static_operator/fully_connect_op.cu b/src/operator/static_operator/fully_connect_op.cu deleted file mode 100644 index 2ff5b565ee88..000000000000 --- a/src/operator/static_operator/fully_connect_op.cu +++ /dev/null @@ -1,16 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file fully_connect_sym.cu - * \brief fully connect operator symbol -*/ -#include "./fully_connect_op-inl.h" -namespace mxnet { -namespace op { - -template<> -StaticOperator* CreateFullyConnectedOp(Param param) { - return new FullyConnectOp(param); -} - -} // namespace op -} // namespace mxnet diff --git a/src/operator/static_operator/static_operator-inl.h b/src/operator/static_operator/static_operator-inl.h deleted file mode 100644 index f03a6a51532e..000000000000 --- a/src/operator/static_operator/static_operator-inl.h +++ /dev/null @@ -1,49 +0,0 @@ -/*! 
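With CUDA enabled, DO_BIND_DISPATCH(CreateFullyConnectedOp, param_) in fully_connected.cc expands to roughly:

if (ctx.dev_mask == cpu::kDevMask) {
  return CreateFullyConnectedOp<cpu>(param_);
} else {
  return CreateFullyConnectedOp<gpu>(param_);  // CPU-only builds LOG(FATAL) instead
}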
- * Copyright (c) 2015 by Contributors - * \file static_operator-inl.h - * \brief static device invarient code to create operators - * \author Bing Xu -*/ -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ -#include -#include -#include -#include "./mshadow_op.h" -#include "./activation_op-inl.h" -#include "./convolution_op-inl.h" -#include "./pooling_op-inl.h" -#include "./reshape_op-inl.h" -#include "./dropout_op-inl.h" - -namespace mxnet { -namespace op { -/*! - * \brief device invariant function to create operators - * \param type the type of operator - * \tparam xpu the device type we are at - */ -template -inline StaticOperator *CreateOperator_(OpType type, mshadow::Random *prnd) { - switch (type) { - case kReLU: - return new ActivationOp(); - case kConv: - return new ConvolutionOp(); - case kMaxPooling: - return new PoolingOp(); - case kAvgPooling: - return new PoolingOp(); - case kFlatten: - return new ReshapeOp(); - case kReshape: - return new ReshapeOp(); - case kDropout: - return new DropoutOp(prnd); - default: LOG(FATAL) << "unknown OpType"; - } - return NULL; -} -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_INL_H_ diff --git a/src/operator/static_operator/static_operator.cc b/src/operator/static_operator/static_operator.cc deleted file mode 100644 index 671ef76f2f9c..000000000000 --- a/src/operator/static_operator/static_operator.cc +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.cc - * \brief - * \author: Bing Xu - */ -#include -#include -#include -#include -#include "./static_operator_common.h" - -namespace mxnet { -namespace op { -/** - * @brief return a OpType based on string description - * - * @param type the string description of operators - * @return the OpType indicated the type of operators - */ -OpType GetOpType(const char *type) { - if (!strcmp(type, "relu")) return kReLU; - if (!strcmp(type, "fullc")) return kFullc; - LOG(FATAL) << "unknown operator type " << type; - return kReLU; -} -} // namespace op - -StaticOperator *StaticOperator::Create(const char *type, - Context ctx) { - op::OpType otype = op::GetOpType(type); - if (ctx.dev_mask == cpu::kDevMask) { - return op::CreateOperator(otype); - } - if (ctx.dev_mask == gpu::kDevMask) { -#if MXNET_USE_CUDA - return op::CreateOperator(otype); -#else - LOG(FATAL) << "GPU is not enabled"; -#endif - } - return NULL; -} // namespace op -} // namespace mxnet diff --git a/src/operator/static_operator/static_operator_cpu.cc b/src/operator/static_operator/static_operator_cpu.cc deleted file mode 100644 index 5b6ea861213b..000000000000 --- a/src/operator/static_operator/static_operator_cpu.cc +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator_cpu.cc - * \brief CPU specialization of operator codes - * \author Bing Xu -*/ -#include "./static_operator-inl.h" - -namespace mxnet { -namespace op { -// todo add managing for prnd -mshadow::Random prnd_cpu(0); - -template<> -StaticOperator *CreateOperator(OpType type) { - return CreateOperator_(type, &prnd_cpu); -} - -} // namespace op -} // namespace mxnet diff --git a/src/operator/static_operator/static_operator_gpu.cu b/src/operator/static_operator/static_operator_gpu.cu deleted file mode 100644 index a66167431dd1..000000000000 --- a/src/operator/static_operator/static_operator_gpu.cu +++ /dev/null @@ -1,22 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - * \file static_operator_gpu.cu - * \brief GPU specialization of operator code - * \author Bing Xu -*/ -#include -#include "static_operator-inl.h" - -namespace mxnet { -namespace op { - -mshadow::Random prnd_gpu(0); - -template<> -StaticOperator *CreateOperator(OpType type) { - return CreateOperator_(type, &prnd_gpu); -} - -} // namespace op -} // namespace mxnet - diff --git a/src/operator/static_operator_wrapper.cc b/src/operator/static_operator_wrapper.cc deleted file mode 100644 index 1690d067c6e6..000000000000 --- a/src/operator/static_operator_wrapper.cc +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file static_operator.cc - * \brief the implementation of static operator - * \author Naiyan Wang - */ -#include -#include -#include -#include -#include -#include - -namespace mxnet { -namespace op { -/*! - * \brief StaticOperatorWrapper that wraps a static_operator - * This class do not need to be seen by others, so it sit in cc file. - * \sa Operator, StaticOperator - */ -class StaticOperatorWrapper: public Operator { - public: - StaticOperatorWrapper(StaticOperator* op, Context ctx) - : op_(op), ctx_(ctx) {} - - virtual ~StaticOperatorWrapper() { - delete op_; - } - - virtual int DescribeProperty() const { - return op_->DescribeProperty(); - } - - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - std::vector used_var; - std::vector mutate_var; - std::vector in; - std::vector out; - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].var()); - in.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_data.size(); ++i) { - mutate_var.push_back(out_data[i].var()); - out.push_back(out_data[i].data()); - } - DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { - op_->Forward(opt, ctx, in, out); - }, ctx_, used_var, mutate_var); - } - - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &out_grad, - const std::vector &req) { - std::vector used_var; - std::vector mutate_var; - std::vector grad_in; - std::vector grad_out; - std::vector data_in; - std::vector data_out; - for (size_t i = 0; i < grad_next.size(); ++i) { - used_var.push_back(grad_next[i].var()); - grad_in.push_back(grad_next[i].data()); - } - for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].var()); - data_in.push_back(in_data[i].data()); - } - for (size_t i = 0; i < out_grad.size(); ++i) { - mutate_var.push_back(out_grad[i].var()); - grad_out.push_back(out_grad[i].data()); - } - DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data_in, data_out, req](RunContext ctx) { - op_->Backward(ctx, grad_in, data_in, data_out, grad_out, req); - }, ctx_, used_var, mutate_var); - } - - private: - /* \brief the static operator */ - StaticOperator* op_; - /** \brief the global context denots the device info. 
*/ - Context ctx_; -}; -} // namespace op - -// implements CreateWrapper -Operator *Operator::CreateWrapper(StaticOperator *op, Context ctx) { - return new op::StaticOperatorWrapper(op, ctx); -} - -} // namespace mxnet diff --git a/src/registry.cc b/src/registry.cc index 04f391cb617c..42fef1df3423 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -30,7 +30,7 @@ template NArrayFunctionEntry &Registry::Register(const std: template Registry *Registry::Get(); #endif -template AtomicSymbolEntry &Registry::Register(const std::string& name); -template Registry *Registry::Get(); +template OperatorPropertyEntry &Registry::Register(const std::string& name); +template Registry *Registry::Get(); } // namespace mxnet diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index ce54ad818bfe..5419e26afe86 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -8,6 +8,7 @@ #include #include +namespace mxnet { std::vector StaticGraph::TopoSort() const { std::vector out_degree(nodes.size(), 0); for (const Node &n : nodes) { @@ -38,20 +39,19 @@ std::vector StaticGraph::TopoSort() const { return std::move(ret); } -bool StaticGraph::InferShape(const std::vector &topo_order, - std::vector > *node_out_shapes) const { - bool success = true; +bool StaticGraph::InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes) const { for (uint32_t nid : topo_order) { const Node &node = nodes[nid]; - if (node.sym != nullptr) { + if (node.op != nullptr) { std::vector in_shape; for (const DataEntry &e : node.inputs) { - in_shape.push_back(node_out_shapes[e.source_id][e.index]); + in_shape.push_back((*node_out_shapes)[e.source_id][e.index]); } - if (!node.sym->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; + if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; for (size_t i = 0; i < node.inputs.size(); ++i) { const DataEntry &e = node.inputs[i]; - node_out_shapes[e.source_id][e.index] = in_shape[i]; + (*node_out_shapes)[e.source_id][e.index] = in_shape[i]; } } } @@ -63,23 +63,25 @@ bool StaticGraph::InferShape(std::vector *in_shape, std::vector > node_out_shapes(nodes.size()); for (size_t i = 0; i < nodes.size(); ++i) { int nout = 1; - if (nodes[i].sym != nullptr) { - nout = nodes[i].sym->NumReturns(); + if (nodes[i].op != nullptr) { + nout = nodes[i].op->NumReturns(); } node_out_shapes[i].resize(nout); } CHECK(in_shape->size() == arg_nodes.size()) << "Wrong number of inputs to infer shape"; for (size_t i = 0; i < arg_nodes.size(); ++i) { - node_out_shapes[nid][0] = (*in_shape)[i]; + node_out_shapes[arg_nodes[i]][0] = (*in_shape)[i]; } if (!InferNodeShapes(this->TopoSort(), &node_out_shapes)) return false; for (size_t i = 0; i < arg_nodes.size(); ++i) { - (*in_shape)[i] = node_out_shapes[nid][0]; + (*in_shape)[i] = node_out_shapes[arg_nodes[i]][0]; } for (size_t i = 0; i < outputs.size(); ++i) { DataEntry e = outputs[i]; (*out_shape)[i] = node_out_shapes[e.source_id][e.index]; } + return true; } +} // namespace mxnet diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index e3700dd127f4..d7a56528fb77 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -5,22 +5,103 @@ */ #include #include -#include #include #include #include #include namespace mxnet { -// copy the symbol +/*! + * \brief Node is represents node of an operator in the symbolic graph. 
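A sketch of driving shape inference on a StaticGraph g; it assumes arg_nodes and outputs are accessible and uses illustrative sizes:

std::vector<TShape> in_shapes(g.arg_nodes.size());
in_shapes[0] = mshadow::Shape2(100, 784);  // known data shape; rest unknown
std::vector<TShape> out_shapes(g.outputs.size());
if (g.InferShape(&in_shapes, &out_shapes)) {
  // in_shapes now also holds the inferred weight/bias shapes
}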
diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index e3700dd127f4..d7a56528fb77 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -5,22 +5,103 @@
  */
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 
 namespace mxnet {
-// copy the symbol
+/*!
+ * \brief Node represents a node of an operator in the symbolic graph.
+ *
+ * It stores the connections to the inputs of the function represented by an OperatorProperty.
+ * NOTE on the data structure: there are three types of node:
+ * - Normal node: contains all the necessary elements of a graph.
+ * - OperatorProperty: inputs is empty; represents an OperatorProperty that has not been applied.
+ * - Variable: op is nullptr; represents a named Variable of tensors that can be composed.
+ */
+struct Symbol::Node {
+  /*! \brief Operator of this node */
+  std::unique_ptr<OperatorProperty> op;
+  /*! \brief name of the node */
+  std::string name;
+  /*! \brief inputs to this node */
+  std::vector<DataEntry> inputs;
+  /*!
+   * \brief constructor
+   * \param op the OperatorProperty to construct the Node
+   * \param name the name of the symbol
+   */
+  explicit Node(OperatorProperty* op = nullptr, const std::string& name = "")
+      : op(op), name(name) {
+  }
+  /*! \return whether the symbol is atomic */
+  inline bool is_atomic() const {
+    return inputs.size() == 0 && op != nullptr;
+  }
+  /*! \return whether it is a unit variable */
+  inline bool is_variable() const {
+    return op == nullptr;
+  }
+};
+
+/*! \return whether the symbol is atomic */
+inline bool Symbol::is_atomic() const {
+  return heads_.size() == 1 && heads_[0].source->is_atomic();
+}
+// implementation of template functions
+template<typename FVisit>
+inline void Symbol::DFSVisit(FVisit fvisit) const {
+  std::vector<Node*> stack;
+  std::unordered_set<Node*> visited;
+  // put the heads onto the stack
+  for (auto &head : heads_) {
+    Node *ptr = head.source.get();
+    if (visited.count(ptr) == 0) {
+      stack.push_back(ptr);
+      visited.insert(ptr);
+    }
+  }
+  while (!stack.empty()) {
+    Node* back = stack.back();
+    stack.pop_back();
+    fvisit(back);
+    for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) {
+      Node *ptr = it->source.get();
+      if (visited.count(ptr) == 0) {
+        stack.push_back(ptr);
+        visited.insert(ptr);
+      }
+    }
+  }
+}
+
+int Symbol::FindDuplicateArgs(std::unordered_map<std::string, int> *out) const {
+  out->clear();
+  int max_dup = 1;
+  this->DFSVisit([out, &max_dup](Node *node) {
+    if (node->is_variable()) {
+      auto iter = out->find(node->name);
+      if (iter == out->end()) {
+        (*out)[node->name] = 1;
+      } else {
+        ++iter->second;
+        max_dup = std::max(max_dup, iter->second);
+      }
+    }
+  });
+  return max_dup;
+}
+
+// public functions
 Symbol Symbol::Copy() const {
   std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
   // use DFSVisit to copy all the nodes
   this->DFSVisit([&old_new](Node *node) {
-    if (node->sym == nullptr) {
+    if (node->op == nullptr) {
       old_new[node] = std::make_shared<Node>(nullptr, node->name);
     } else {
-      old_new[node] = std::make_shared<Node>(node->sym->Copy(), node->name);
+      old_new[node] = std::make_shared<Node>(node->op->Copy(), node->name);
     }
   });
   // connect nodes of new graph
@@ -40,7 +121,7 @@ Symbol Symbol::Copy() const {
 
 void Symbol::Print(std::ostream &os) const {
   if (this->is_atomic()) {
-    os << "AtomicSymbol " << " Type:" << heads_[0].source->sym->TypeString() << '\n'
+    os << "AtomicFunction " << " Type:" << heads_[0].source->op->TypeString() << '\n'
        << "Inputs:";
     std::vector<std::string> args = this->ListArguments();
     for (size_t i = 0; i < args.size(); ++i) {
@@ -57,7 +138,7 @@ void Symbol::Print(std::ostream &os) const {
     if (node->is_variable()) {
       os << "Variable:" << node->name << '\n';
     } else {
-      os << "Name: " << node->name << " Type:" << node->sym->TypeString() << '\n'
+      os << "Name: " << node->name << " Type:" << node->op->TypeString() << '\n'
          << "Inputs:\n";
       for (size_t i = 0; i < node->inputs.size(); ++i) {
         os << "\targ[" << i << "]=" << node->inputs[i].source->name
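DFSVisit above is the workhorse behind Copy, Print, and FindDuplicateArgs: an explicit stack plus a visited set, so shared subgraphs are visited exactly once even when several heads reach the same node. The same traversal in stand-alone form (toy Node type, not the Symbol::Node above):

    #include <cstdio>
    #include <unordered_set>
    #include <vector>

    struct Node {
      const char* name;
      std::vector<Node*> inputs;
    };

    template <typename FVisit>
    void DFSVisit(const std::vector<Node*>& heads, FVisit fvisit) {
      std::vector<Node*> stack;
      std::unordered_set<Node*> visited;
      for (Node* h : heads) {                // seed the stack with the heads
        if (visited.insert(h).second) stack.push_back(h);
      }
      while (!stack.empty()) {
        Node* back = stack.back();
        stack.pop_back();
        fvisit(back);                        // visit before expanding inputs
        for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) {
          if (visited.insert(*it).second) stack.push_back(*it);
        }
      }
    }

    int main() {
      // Diamond: a feeds both b and c, but is visited only once.
      Node a{"a", {}}, b{"b", {&a}}, c{"c", {&a, &b}};
      DFSVisit({&c}, [](Node* n) { std::printf("%s\n", n->name); });
      return 0;
    }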
@@ -68,21 +149,49 @@ void Symbol::Print(std::ostream &os) const {
   }
 }
 
-int Symbol::FindDuplicateArgs(std::unordered_map<std::string, int> *out) const {
-  out->clear();
-  int max_dup = 1;
-  this->DFSVisit([out, &max_dup](Node *node) {
-    if (node->is_variable()) {
-      auto iter = out->find(node->name);
-      if (iter == out->end()) {
-        (*out)[node->name] = 1;
-      } else {
-        ++iter->second;
-        max_dup = std::max(max_dup, iter->second);
+std::vector<std::string> Symbol::ListArguments() const {
+  std::vector<std::string> ret;
+  if (this->is_atomic()) {
+    return heads_[0].source->op->ListArguments();
+  } else {
+    this->DFSVisit([&ret](Node *node) {
+      if (node->is_variable()) {
+        ret.push_back(node->name);
       }
+    });
+    return ret;
+  }
+}
+
+std::vector<std::string> Symbol::ListReturns() const {
+  std::vector<std::string> ret;
+  for (auto &head : heads_) {
+    if (head.source->is_variable()) {
+      ret.push_back(head.source->name);
+    } else {
+      // TODO(bing) rethink about output naming
+      auto &hname = head.source->name;
+      std::string rname = head.source->op->ListReturns()[head.index];
+      if (hname.length() == 0) {
+        ret.push_back(std::move(rname));
+      } else {
+        ret.push_back(hname + '_' + rname);
       }
-  });
-  return max_dup;
+    }
+  }
+  return std::move(ret);
+}
+
+Symbol Symbol::operator[] (size_t index) const {
+  size_t nreturn = NumReturns();
+  CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
+  if (nreturn == 1) {
+    return *this;
+  } else {
+    Symbol s;
+    s.heads_.push_back(heads_[index]);
+    return s;
+  }
 }
 
 void Symbol::Compose(const std::vector<Symbol>& args,
@@ -98,7 +207,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   // TODO(bing) consider partial assignments
   if (this->is_atomic()) {
     // atomic symbol do not have place holder for all the arguments
-    std::vector<std::string> req_args = heads_[0].source->sym->ListArguments();
+    std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
     CHECK_EQ(args.size(), req_args.size())
         << "Incorrect number of arguments, requires " << req_args.size()
         << ", provided " << args.size();
@@ -154,7 +263,7 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
   size_t nmatched = 0;
   if (this->is_atomic()) {
     // atomic symbol do not have place holder for all the arguments
-    std::vector<std::string> req_args = heads_[0].source->sym->ListArguments();
+    std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
     heads_[0].source->inputs.resize(req_args.size());
     for (size_t i = 0; i < req_args.size(); ++i) {
       auto iter = kwargs.find(req_args[i]);
@@ -235,55 +344,31 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
   }
 }
 
-Symbol Symbol::operator[] (size_t index) const {
-  size_t nreturn = NumReturns();
-  CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
-  if (nreturn == 1) {
-    return *this;
-  } else {
-    Symbol s;
-    s.heads_.push_back(heads_[index]);
-    return s;
-  }
+Symbol Symbol::operator () (const std::vector<Symbol>& args,
+                            const std::string& name) const {
+  Symbol s = this->Copy();
+  s.Compose(args, name);
+  return s;
 }
 
-std::vector<std::string> Symbol::ListArguments() const {
-  std::vector<std::string> ret;
-  if (this->is_atomic()) {
-    return heads_[0].source->sym->ListArguments();
-  } else {
-    this->DFSVisit([&ret](Node *node) {
-      if (node->is_variable()) {
-        ret.push_back(node->name);
-      }
-    });
-    return ret;
-  }
+Symbol Symbol::operator () (const std::unordered_map<std::string, Symbol>& kwargs,
+                            const std::string& name) const {
+  Symbol s = this->Copy();
+  s.Compose(kwargs, name);
+  return s;
 }
 
-std::vector<std::string> Symbol::ListReturns() const {
-  std::vector<std::string> ret;
-  for (auto &head : heads_) {
-    if (head.source->is_variable()) {
-      ret.push_back(head.source->name);
-    } else {
-      // TODO(bing) rethink about output naming
-      auto &hname = head.source->name;
-      std::string rname = head.source->sym->ListReturns()[head.index];
-      if (hname.length() == 0) {
-        ret.push_back(std::move(rname));
-      } else {
-        ret.push_back(hname + '_' + rname);
-      }
-    }
-  }
-  return std::move(ret);
+bool Symbol::InferShape(std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) const {
+  StaticGraph g;
+  this->ToStaticGraph(&g);
+  return g.InferShape(in_shape, out_shape);
 }
 
-Symbol Symbol::Create(AtomicSymbol *atomic_symbol) {
+Symbol Symbol::Create(OperatorProperty *op) {
   // use special representation for atomic symbol
-  auto node = std::make_shared<Node>(atomic_symbol, "");
-  size_t nret = atomic_symbol->NumReturns();
+  auto node = std::make_shared<Node>(op, "");
+  size_t nret = op->NumReturns();
   Symbol s;
   for (uint32_t i = 0; i < nret; ++i) {
     s.heads_.push_back(DataEntry(node, i));
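The keyword form of Compose above walks the operator's required argument names and looks each one up in kwargs, counting how many keywords matched so that unknown or leftover keywords can be detected afterwards. The matching step in isolation (toy types; the real code stores DataEntry inputs and falls back to placeholder variables for unmatched slots):

    #include <cassert>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      std::vector<std::string> req_args = {"data", "weight", "bias"};
      std::unordered_map<std::string, int> kwargs = {{"data", 7}, {"weight", 8}};
      std::vector<int> inputs(req_args.size(), -1);  // -1: placeholder slot
      size_t nmatched = 0;
      for (size_t i = 0; i < req_args.size(); ++i) {
        auto iter = kwargs.find(req_args[i]);
        if (iter != kwargs.end()) {
          inputs[i] = iter->second;
          ++nmatched;
        }
      }
      // If some keyword was not consumed, it names no required argument.
      assert(nmatched == kwargs.size());
      return 0;
    }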
@@ -291,6 +376,20 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) {
   return s;
 }
 
+Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
+  Symbol ret;
+  for (const auto &s : symbols) {
+    ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end());
+  }
+  return std::move(ret);
+}
+
+Symbol Symbol::CreateVariable(const std::string &name) {
+  Symbol s;
+  s.heads_.push_back(DataEntry(std::make_shared<Node>(nullptr, name), 0));
+  return std::move(s);
+}
+
 void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
   // TODO(bing): Check unique name
   std::vector<Node*> node_order;
@@ -309,10 +408,10 @@ void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
   // setup nodes
   out_graph->nodes.resize(node_index.size());
   for (uint32_t nid = 0; nid < node_order.size(); ++nid) {
-    if (node_order[nid]->sym != nullptr) {
-      out_graph->nodes[nid].sym.reset(node_order[nid]->sym->Copy());
+    if (node_order[nid]->op != nullptr) {
+      out_graph->nodes[nid].op.reset(node_order[nid]->op->Copy());
     } else {
-      out_graph->nodes[nid].sym.reset(nullptr);
+      out_graph->nodes[nid].op.reset(nullptr);
     }
     out_graph->nodes[nid].name = node_order[nid]->name;
     auto &inputs = out_graph->nodes[nid].inputs;