From cc2195f44f7996d0d8079747724a4bcc5b8ccfce Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 11 Aug 2015 20:21:24 -0600 Subject: [PATCH 1/2] chg interface --- ...osite_operator.h => composite_operator.cc} | 102 +++++++++--------- 1 file changed, 49 insertions(+), 53 deletions(-) rename src/operator/{composite_operator.h => composite_operator.cc} (52%) diff --git a/src/operator/composite_operator.h b/src/operator/composite_operator.cc similarity index 52% rename from src/operator/composite_operator.h rename to src/operator/composite_operator.cc index 12297dc41c43..1853c0539000 100644 --- a/src/operator/composite_operator.h +++ b/src/operator/composite_operator.cc @@ -1,11 +1,9 @@ /*! * Copyright (c) 2015 by Contributors - * \file composite_operator.h + * \file composite_operator.cc * \brief composite operator of mxnet * \author Bing Xu */ -#ifndef MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ -#define MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ #include #include #include @@ -39,70 +37,65 @@ class CompositeOperator : public Operator { void Bind(Context ctx, const std::vector &in, const std::vector &grad - const std::vector &req); + const std::vector &req) { + ctx_ = ctx; + // infer shape + // build dict + // alloc nodes + // alloc feature map + UpdateConnection(in, grad, req); + } /*! - * \brief perform a forward operation of operator, save the output to NArray - * This method only pushes an execution request to the DAG engine, and - * return immediately. Actual execution is conducted by the DAG engine. - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of NArray in out_data must be pre-allocated with InferShape - * \sa NArray + * \brief Update connections data in/after bind + * \param in input narray + * \param grad gradient narray + * \param req gradient request */ - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data); + void UpdateConnection(const std::vector &in, + const std::vector &grad, + const std::vector &req) { + CHECK_EQ(in.size() == nodes_.size()); + CHECK_EQ(grad.size() == nodes_.size()); + CHECK_EQ(req.size() == nodes_.size()); + } /*! * \brief perform a forward operation of operator (no change to binded NArray) * \param opt option on Forward such as whether this is training phase */ - virtual void Forward(Option opt); - /*! - * \brief perform a backward operation of the operator to get the gradient - * This method only pushes an execution request to the DAG engine, and - * return immediately. Actual execution is conducted by the DAG engine. - * \param ctx runtime context - * \param grad_next the gradient value of the output of the operator, used by chain rule. - * \param in_data the array of input data - * \param out_grad array of output gradient - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType, NArray - */ - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req); + virtual void Forward(Option opt) { + for (auto nid : topo_order_) { + if (nodes_[nid].is_variable) continue; + nodes_[nid].op->Forward(opt, + ctx_, + nodes_[nid].inputs, + nodes_[nid].outputs); + } + } /*! 
* \brief perform a backward operation of the operator to get the gradient * No change to Binded NArray */ - virtual void Backward(); + virtual void Backward() { + for (auto it = topo_order_.rbegin(); it < topo_order_.rend(); ++it) { + if (nodes_[*it].is_variable) continue; + nodes_[*it].op->Backward(ctx_, + nodes_[*it].outputs, + nodes_[*it].inputs, + nodes_[*it].outputs_grad, + nodes_[*it].req); + } + } /*! - * \brief perform an extraction operation to get feature map + * \brief perform an extraction operation to get outputs * \param name of symbol need to be extracted * \return empty narray for invalid name or narray of the feature map */ - virtual NArray Extract(const std::string &symbol_name); - + virtual std::vector Extract(const std::string &symbol_name) { + auto it = name_dict_.find(symbol_name); + if (it == name_dict_.end()) return {}; + return nodes_[it->second].outputs; + } private: - /*! - * \brief Update connections data in/after bind - * \param in input narray - * \param grad gradient narray - * \param req gradient request - */ - void UpdateConnection(const std::vector &in, - const std::vector &grad, - const std::vector &req); - /*! - * \brief Allocate each op node - */ - void AllocateNodes(RunContext ctx); /*! * \brief Structure for OpNode */ @@ -126,6 +119,9 @@ class CompositeOperator : public Operator { std::vector topo_order_; /*! \brief static graph */ StaticGraph graph_; + /*! \brief running context */ + RunContext ctx_; + /*! \brief name id dictionary */ + std::unordered_map name_dict_; }; // class CompositeOperator } // namespace mxnet -#endif // MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ From 270bfb28e77c5800a2b8f38f2689f4d43fa17faf Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Thu, 13 Aug 2015 23:10:33 -0600 Subject: [PATCH 2/2] discussed interface --- Makefile | 6 +- include/mxnet/base.h | 33 ---- include/mxnet/c_api.h | 2 + include/mxnet/operator.h | 142 +++++++++--------- include/mxnet/symbolic.h | 130 +++++++++++++++- src/operator/static_operator/dropout_op-inl.h | 1 + .../static_operator/fully_connect_op-inl.h | 88 ++++++----- .../static_operator/fully_connect_op.cc | 19 +-- .../static_operator/fully_connect_op.cu | 8 +- .../static_operator/static_operator_common.h | 37 +++-- 10 files changed, 276 insertions(+), 190 deletions(-) diff --git a/Makefile b/Makefile index 2c9eb787889b..a6da0c4206b0 100644 --- a/Makefile +++ b/Makefile @@ -56,16 +56,16 @@ endif #BIN = test/test_threaded_engine test/api_registry_test BIN = test/api_registry_test -OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o +OBJ = storage.o narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o operator.o fully_connect_op_cpu.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connect_op_cpu.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o static_operator_gpu.o fully_connect_op_gpu.o + CUOBJ += narray_op_gpu.o fully_connect_op_gpu.o endif .PHONY: clean all test lint doc diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 6256947faf5c..388ad8c23e90 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -41,39 +41,6 @@ typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; -/*! \brief option to pass into the forward function */ -struct Option { - /*! \brief whether it is training phase*/ - int is_train; -}; -/*! 
\brief gradient request type the request can have */ -enum GradReqType { - /*! \brief no operation, do not write gradient */ - kNullOp = 0, - /*! \brief write gradient to provided space */ - kWriteTo = 1, - /*! \brief same as kWriteTo, but provided space is same as space of input-data */ - kWriteInplace = 2, - /*! \brief add to the provided space */ - kAddTo = 3 -}; -/*! \brief input argument type of the operator have */ -enum ArgType { - /*! \brief data argument */ - kDataArg = 0, - /*! \brief weight argument */ - kWeightArg = 1, - /*! \brief bias argument */ - kBiasArg = 2 -}; -/*! \brief Property for engine schedule */ -enum Property { - /*! \brief Op contains interanl state, won't influence engine schedule */ - kContainInteralState = 1, - /*! \brief Op forward require random number, will influence engine schedule */ - kForwardRequireRnd = 2, -}; - /*! \brief context information about the execution enviroment */ struct Context { /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 29c9691e8ff5..bb718b6f9fdb 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -385,6 +385,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op, * \param op the operator handle * \param grad_next array of output gradients * \param in_data array of input narray to the operator + * \param out_data array of output narray to the operator * \param out_grad array to holds the gradient on these input * can be NULL if that position request is kNullOp * \param reqs gradient request type @@ -394,6 +395,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op, MXNET_DLL int MXOpBackward(OperatorHandle op, NArrayHandle *grad_next, NArrayHandle *in_data, + NArrayHandle *out_data, NArrayHandle *out_grad, mx_uint *reqs); diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index c1a53df61fa9..7fb6e0e895ee 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -10,11 +10,34 @@ #include #include #include "./base.h" -#if DMLC_USE_CXX11 == 1 +#if DMLC_USE_CXX11 #include "./narray.h" #include "./dag_engine.h" #endif +#include "./symbolic.h" + namespace mxnet { +/*! \brief option to pass into the forward function */ +struct Option { + /*! \brief whether it is training phase*/ + int is_train; +}; + +/*! \brief operation request type to Forward and Backward */ +enum OpReqType { + /*! \brief no operation, do not write anything */ + kNullOp, + /*! \brief write gradient to provided space */ + kWriteTo, + /*! + * \brief perform an inplace write, + * Target shares memory with one of input arguments. + * This option only happen when + */ + kWriteInplace, + /*! \brief add to the provided space */ + kAddTo +}; /*! * \brief StaticOperator interface * StaticOperator is a stateful object that can be used to call forward and backprop @@ -29,108 +52,77 @@ class StaticOperator { /*! \brief destructor */ virtual ~StaticOperator() {} /*! - * \brief describe property of op - * \return a bit map in int - */ - virtual int DescribeProperty() const { - // default most of layer only conatin internal state - return kContainInteralState; - } - /*! - * \brief perform a forward operation of StaticOperator, save the output to TBlob - * \param opt option on Forward such as whether this is training phase + * \brief perform a forward operation of StaticOperator, save the output to TBlob. + * \param opt option on Forward such as whether this is training phase. 
* \param ctx runtime context * \param in_data array of input data, it is const - * \param out_data array of output data, + * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace. + * \param out_data array of output data, pointer is used to indicate that this is holder * the space of TBlob in out_data must be pre-allocated with InferShape + * \sa OpReqType */ virtual void Forward(Option opt, RunContext ctx, const std::vector &in_data, + const std::vector &req, const std::vector &out_data) = 0; /*! - * \brief perform a backward operation of the StaticOperator to get the gradient + * \brief Perform a backward Operation, write gradient to the in_grad. * \param ctx runtime context - * \param grad_next the gradient value we get from output of the StaticOperator - * \param in_data the array of input data - * \param out_grad array of output gradient, there could be three possible TBlob - * in the each element in the array - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType + * \param out_grad the gradient value we get from output of the StaticOperator + * \param in_data the array of input data. + * \param out_data the array of output data. + * \param req request types of the saving operation, can be all types. + * \param in_grad the array of gradient we need to write to. + * \sa OpReqType */ virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, const std::vector &out_grad, - const std::vector &req) = 0; - /*! - * \brief factory function, create a new StaticOperator - * \param type the type of StaticOperator - * \param ctx the context device type of StaticOperator - * \return a pointer of StaticOperator object - */ - static StaticOperator *Create(const char *type, Context ctx); + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) = 0; }; -#if DMLC_USE_CXX11 == 1 +#if DMLC_USE_CXX11 /*! - * \brief operator interface - * operator is an object can be scheduled by DAG engine directly. + * \brief Operator interface. + * Operator is an object can have Forward and Backward function. * - * This interface relies on NArray. The user should prepare the input NArray and - * output NArray by themselves. - * \sa Operator + * It can be created from */ class Operator { public: /*! \brief destructor */ virtual ~Operator() {} /*! - * \brief describe property of op - * \return a bit map in int - */ - virtual int DescribeProperty() const = 0; - /*! - * \brief perform a forward operation of operator, save the output to NArray - * This method only pushes an execution request to the DAG engine, and - * return immediately. Actual execution is conducted by the DAG engine. - * \param opt option on Forward such as whether this is training phase - * \param ctx runtime context - * \param in_data array of input data, it is const - * \param out_data array of output data, - * the space of NArray in out_data must be pre-allocated with InferShape - * \sa NArray + * \brief Perform a Forward operation of Operator + * After this operation, user can get the result by using function head. */ - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) = 0; + virtual void Forward() = 0; /*! - * \brief perform a backward operation of the operator to get the gradient - * This method only pushes an execution request to the DAG engine, and - * return immediately. 
Actual execution is conducted by the DAG engine. - * \param ctx runtime context - * \param grad_next the gradient value of the output of the operator, used by chain rule. - * \param in_data the array of input data - * \param out_grad array of output gradient - * \param req request types of the gradient saving operation - * only inplace will change input data - * \sa GradReqType, NArray + * \brief Perform a Backward operation of the Operator. + * This must be called after Forward. + * After this operation, NArrays specified by grad_in_args_store will be updated accordingly. */ - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_grad, - const std::vector &req) = 0; + virtual void Backward() = 0; + /*! \return get array of heads in the operator */ + virtual const std::vector &head() const = 0; /*! - * \brief Create a wrapper of static operator to wrap it into Operator. - * This function takes ownership of op - * \param op static operator to wrap from - * \param ctx context of the created operator - * \return a wrapper operator + * \brief Create an operator by bind symbol with context and arguments. + * If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp. + * + * \param ctx the context of binding. + * \param symbol the symbol that specifies the output of Forward pass. + * \param in_args the NArray that stores the input arguments to the symbol. + * \param grad_in_args_store NArray that is used to store the gradient output of the input arguments. + * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. */ - static Operator *CreateWrapper(StaticOperator *op, Context ctx); + static Operator *Bind(Symbol symbol, + Context ctx, + const std::vector &in_args, + const std::vector &grad_in_args_store, + const std::vector &grad_req_type); }; // class operator #endif } // namespace mxnet diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 0b3aead32e54..c5a92fb07e35 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -12,7 +12,8 @@ #include #include #include -#if DMLC_USE_CXX11 == 1 +#include +#if DMLC_USE_CXX11 #include #include #endif @@ -21,6 +22,7 @@ namespace mxnet { // forward declare StaticOperator class StaticOperator; +#if DMLC_USE_CXX11 /*! * \brief AtomicSymbol is the base class of all atomic symbols. * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance @@ -78,15 +80,129 @@ class AtomicSymbol { * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. * Calling bind from the Symbol wrapper would generate a NArrayOperator. */ - template StaticOperator* Bind(Context ctx) const; /*! * \brief return the type string of the atomic symbol * subclasses override this function. */ virtual std::string TypeString() const = 0; - friend class Symbol; - + /*! + * \brief Declare the input requirement of Backward pass. + * + * Only the returned list of variables will be used in Backward. + * This function is used for memory optimization. + * It is adviced to override and only return what is actually needed. + * If this function is not overriden, all the variables will be valid in Backward. 
+ * + * \code + * // The following code declares Backward need out_grad[0], in_data[0],in_data[1] + * vector BackwardInputs(const vector &out_grad, + * const vector &in_data, + * const vector &out_data) const { + * return {out_grad[0], in_data[0], in_data[1]}; + * } + * \endcode + * \param out_grad gradient of outputs in backward pass. + * \param in_data the input data in forward pass. + * \param out_data the output data in forward pass. + * \return an integer vector indicating the input requirments + * \sa BackwardInputs + */ + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + // By default requires to see all the things. + // remember to override this function to get a better performance. + std::vector ret = out_grad; + ret.insert(ret.end(), in_data.begin(), in_data.end()); + ret.insert(ret.end(), out_data.begin(), out_data.end()); + return ret; + } + /*! + * \brief Get possible forward inplace options. + * This function enables optimization to reuse memory of inputs in output. + * Only override when necessary, by default in-place is disabled. + * + * \code + * // The following code says out_data[0] can share data with in_data[0] + * vector > ForwardInplaceOption(const vector &in_data, + * const vector &out_data) const { + * return {{out_data[0], in_data[0]}}; + * } + * \endcode + * \return list of pair of integers taken from the inputs vector, + * indicating possible in place operations. + */ + virtual std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const { + return std::vector >(); + } + /*! + * \brief Get possible backward inplace options. + * This function enables optimization to reuse memory of inputs in output. + * Only override when necessary, by default in-place is disabled. + * + * \code + * // The following code says in_grad[0] can share data with in_data[0] + * vector > BackwardInplaceOption( + * const std::vector &out_grad, + * const std::vector &in_data, + * const std::vector &out_data, + * const std::vector &in_grad) const { + * return {in_grad[0], in_data[0]}}; + * } + * \endcode + * \return list of pair of integers taken from the inputs vector, + * indicating possible in place operations. + */ + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return std::vector >(); + } + /*! + * \brief Get Backward Input Dependency for generic types of data. + * Normally T can be pointer of Symbol::DataEntry, or NArray. + * This function will select the result list of T according to DeclareBackwardDependency. + * + * \param in_data the input data in forward pass. + * \param out_data the output data in forward pass. + * \param out_grad gradient of outputs in backward pass. + * \tparam T the generic type parameter. + * \return vector of inputs the Backward Operation depends on. 
+ * \sa DeclareBackwardDependency + */ + template + inline std::vector BackwardInputs(const std::vector &in_data, + const std::vector &out_data, + const std::vector &out_grad) const { + int cnt = 0; + std::vector all_vec; + std::vector in_data_idx, out_data_idx, out_grad_idx; + for (size_t i = 0; i < in_data.size(); ++i) { + in_data_idx.push_back(cnt++); + all_vec.push_back(in_data[i]); + } + for (size_t i = 0; i < out_data.size(); ++i) { + out_data_idx.push_back(cnt++); + all_vec.push_back(out_data[i]); + } + for (size_t i = 0; i < out_grad.size(); ++i) { + out_grad_idx.push_back(cnt++); + all_vec.push_back(out_data[i]); + } + std::vector ret_idx = this->DeclareBackwardDependency( + in_data_idx, out_data_idx, out_grad_idx); + std::vector ret; + for (size_t i = 0; i < ret_idx.size(); ++i) { + ret.push_back(all_vec[ret_idx[i]]); + } + return ret; + } /*! * \brief create atomic symbol by type name * \param type_name the type string of the AtomicSymbol @@ -94,7 +210,7 @@ class AtomicSymbol { */ static AtomicSymbol *Create(const char* type_name); }; -#if DMLC_USE_CXX11 == 1 + /*! * \brief StaticGraph is the configuration of computation graphs. * This is the "configuration file" of mxnet. @@ -162,8 +278,6 @@ class StaticGraph { bool InferShape(std::vector *in_shape, std::vector *out_shape) const; }; -#endif -#if DMLC_USE_CXX11 == 1 /*! * \brief Symbol is used to represent dynamically generated symbolic computation graph. * @@ -422,6 +536,6 @@ class Symbol { */ int FindDuplicateArgs(std::unordered_map *out) const; }; -#endif +#endif // DMLC_USE_CXX11 } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ diff --git a/src/operator/static_operator/dropout_op-inl.h b/src/operator/static_operator/dropout_op-inl.h index b79a79fbea65..23c9f6aab457 100644 --- a/src/operator/static_operator/dropout_op-inl.h +++ b/src/operator/static_operator/dropout_op-inl.h @@ -59,6 +59,7 @@ class DropoutOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); diff --git a/src/operator/static_operator/fully_connect_op-inl.h b/src/operator/static_operator/fully_connect_op-inl.h index d39335deeeff..9bdca812fa20 100644 --- a/src/operator/static_operator/fully_connect_op-inl.h +++ b/src/operator/static_operator/fully_connect_op-inl.h @@ -11,11 +11,17 @@ #include #include #include +#include #include "./static_operator_common.h" #include "./param.h" namespace mxnet { namespace op { +// Declare enumeration of input order to make code more intuitive. +// These enums are only visible within this header +enum FullyConnectOpInputs {kData, kWeight, kBias}; +enum FullyConnectOpOutputs {kOut}; + /** * \brief This is the implementation of fully connected layer. * @@ -34,55 +40,58 @@ class FullyConnectOp : public StaticOperator { virtual void Forward(Option opt, RunContext ctx, const std::vector &in_data, + const std::vector &req, const std::vector &out_data) { using namespace mshadow; using namespace mshadow::expr; + CHECK_EQ(req[kOut], kWriteTo); size_t expected = param_.no_bias == 0 ? 
3 : 2; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context Stream *s = static_cast *>(ctx.stream); - Tensor data = in_data[0].FlatTo2D(s); - Tensor wmat = in_data[1].get(s); - Tensor out = out_data[0].FlatTo2D(s); + Tensor data = in_data[kData].FlatTo2D(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor out = out_data[kOut].FlatTo2D(s); out = dot(data, wmat.T()); if (param_.no_bias == 0) { - Tensor bias = in_data[2].get(s); + Tensor bias = in_data[kBias].get(s); out += repmat(bias, data.size(0)); } } virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, const std::vector &out_grad, - const std::vector &req) { + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(grad_next.size(), 1); + CHECK_EQ(out_grad.size(), 1); size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK(in_data.size() == expected && out_grad.size() == expected); + CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context Stream *s = static_cast *>(ctx.stream); - Tensor data = in_data[0].FlatTo2D(s); - Tensor wmat = in_data[1].get(s); - Tensor grad = grad_next[0].FlatTo2D(s); + Tensor data = in_data[kData].FlatTo2D(s); + Tensor wmat = in_data[kWeight].get(s); + Tensor grad = out_grad[kOut].FlatTo2D(s); // backprop - CHECK_NE(req[1], kWriteInplace) << "cannot write weight inplace"; + CHECK_NE(req[kWeight], kWriteInplace) << "cannot write weight inplace"; // gradient of weight - Tensor gwmat = out_grad[1].get(s); - Assign(gwmat, req[1], dot(grad.T(), data)); + Tensor gwmat = in_grad[kWeight].get(s); + Assign(gwmat, req[kWeight], dot(grad.T(), data)); // gradient of bias if (param_.no_bias == 0) { - Tensor gbias = out_grad[2].get(s); - Assign(gbias, req[2], sum_rows(grad)); + Tensor gbias = in_grad[kBias].get(s); + Assign(gbias, req[kBias], sum_rows(grad)); } // gradient of data - Tensor gdata = out_grad[0].FlatTo2D(s); - Assign(gdata, req[0], dot(grad, wmat)); + Tensor gdata = in_grad[kData].FlatTo2D(s); + Assign(gdata, req[kData], dot(grad, wmat)); } private: @@ -90,17 +99,21 @@ class FullyConnectOp : public StaticOperator { Param param_; }; // class FullyConnectOp +// Decalre factory function, used for dispatch specialization +template +StaticOperator* CreateFullyConnectedOp(Param param); + +#if DMLC_USE_CXX11 /** * @brief The symbol part of the fully connected layer. 
*/ class FullyConnectSymbol : public AtomicSymbol { public: virtual std::vector ListArguments() const { - std::string ret[] = {"data", "weight", "bias"}; if (param_.no_bias == 0) { - return std::vector(ret, ret + 3); + return {"data", "weight", "bias"}; } else { - return std::vector(ret, ret + 2); + return {"data", "weight"}; } } @@ -121,9 +134,9 @@ class FullyConnectSymbol : public AtomicSymbol { CHECK_EQ(dshape.ndim(), 4) << \ "Input data should be 4D in batch-1-1-hidden"; CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; - ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); + ShapeAssignCheck((*in_shape)[kWeight], Shape2(param_.num_hidden, dshape[3])); if (param_.no_bias == 0) { - ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); + ShapeAssignCheck((*in_shape)[kBias], Shape1(param_.num_hidden)); } out_shape->clear(); out_shape->push_back(dshape); @@ -140,23 +153,30 @@ class FullyConnectSymbol : public AtomicSymbol { virtual std::string TypeString() const { return "FullyConnected"; } + // decalre dependency and inplace optimization options + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + return {out_grad[kOut], in_data[kData], in_data[kWeight]}; + } - /** - * @brief This is the template function of bind() implementation. - * - * @param ctx The device context - * @return A device dependent static operator can be used for execution. - */ - template - StaticOperator* Bind_(Context ctx) const; - // the real bind + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return {{in_grad[kData], in_data[kData]}}; + } + + // bind function StaticOperator* Bind(Context ctx) const; private: /** The param of the fully connected layer.*/ Param param_; }; // class FullyConnectSymbol - +#endif } // namespace op } // namespace mxnet diff --git a/src/operator/static_operator/fully_connect_op.cc b/src/operator/static_operator/fully_connect_op.cc index 9f3cad3292b0..69687024384e 100644 --- a/src/operator/static_operator/fully_connect_op.cc +++ b/src/operator/static_operator/fully_connect_op.cc @@ -8,26 +8,15 @@ namespace mxnet { namespace op { template<> -StaticOperator* FullyConnectSymbol::Bind_(Context ctx) const { - return new FullyConnectOp(param_); +StaticOperator* CreateFullyConnectedOp(Param param) { + return new FullyConnectOp(param); } -// put this after the template specialization +// DO_BIND_DISPATCH comes from static_operator_common.h StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - if (ctx.dev_mask == cpu::kDevMask) { - return Bind_(ctx); - } else { - #if MXNET_USE_CUDA - return Bind_(ctx); - #else - LOG(FATAL) << "GPU is not enabled"; - return NULL; - #endif - } + DO_BIND_DISPATCH(CreateFullyConnectedOp, param_); } -// register the symbol REGISTER_ATOMIC_SYMBOL(FullyConnected, FullyConnectSymbol); - } // namespace op } // namespace mxnet diff --git a/src/operator/static_operator/fully_connect_op.cu b/src/operator/static_operator/fully_connect_op.cu index 8e3efbcaddfd..2ff5b565ee88 100644 --- a/src/operator/static_operator/fully_connect_op.cu +++ b/src/operator/static_operator/fully_connect_op.cu @@ -3,12 +3,14 @@ * \file fully_connect_sym.cu * \brief fully connect operator symbol */ -#include "../static_operator/fully_connect_op-inl.h" +#include "./fully_connect_op-inl.h" namespace mxnet { namespace op { + 
 template<>
-StaticOperator* FullyConnectSymbol::Bind_<gpu>(Context ctx) const {
-  return new FullyConnectOp<gpu>(param_);
+StaticOperator* CreateFullyConnectedOp<gpu>(Param param) {
+  return new FullyConnectOp<gpu>(param);
 }
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/static_operator/static_operator_common.h b/src/operator/static_operator/static_operator_common.h
index 06eb307b8ca0..f90b9ffd6ce3 100644
--- a/src/operator/static_operator/static_operator_common.h
+++ b/src/operator/static_operator/static_operator_common.h
@@ -24,7 +24,7 @@ namespace op {
  */
 template<typename OType, typename Exp>
 inline void Assign(OType &out,  // NOLINT(*)
-                   GradReqType req,
+                   OpReqType req,
                    const Exp &exp) {
   switch (req) {
     case kNullOp: break;
@@ -49,25 +49,24 @@ inline void ShapeAssignCheck(TShape &out, const TS &shape) {  // NOLINT(*)
   }
 }
-/*! \brief type of operators */
-enum OpType {
-  kReLU = 0,
-  kFullc = 1,
-  kConv = 2,
-  kMaxPooling = 3,
-  kAvgPooling = 4,
-  kSumPooling = 5,
-  kFlatten = 6,
-  kReshape = 7,
-  kDropout = 8,
-};
+// definition of the dispatch macro
+#if MXNET_USE_CUDA
+#define DO_BIND_DISPATCH(Method, ...)        \
+  if (ctx.dev_mask == cpu::kDevMask) {       \
+    return Method<cpu>(__VA_ARGS__);         \
+  } else {                                   \
+    return Method<gpu>(__VA_ARGS__);         \
+  }
+#else
+#define DO_BIND_DISPATCH(Method, ...)        \
+  if (ctx.dev_mask == cpu::kDevMask) {       \
+    return Method<cpu>(__VA_ARGS__);         \
+  } else {                                   \
+    LOG(FATAL) << "GPU is not enabled";      \
+    return nullptr;                          \
+  }
+#endif
-
-/*!
- * \brief device invariant function to create operators
- * \param type the type of operator
- */
-template<typename xpu>
-StaticOperator *CreateOperator(OpType type);
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_STATIC_OPERATOR_STATIC_OPERATOR_COMMON_H_
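
The interface discussed in PATCH 2/2 replaces the per-call Forward/Backward argument lists with a bound Operator object. As a rough illustration only (not part of the patch), the sketch below shows how a caller might drive that interface. The helper name RunOnePass is invented here, and the element types of the argument vectors (NArray, OpReqType) are assumptions, since the angle-bracketed template arguments do not survive in the patch text above.

#include <vector>
#include <mxnet/base.h>
#include <mxnet/narray.h>
#include <mxnet/operator.h>
#include <mxnet/symbolic.h>

void RunOnePass(mxnet::Symbol net, mxnet::Context ctx,
                const std::vector<mxnet::NArray> &in_args,
                const std::vector<mxnet::NArray> &arg_grad_store) {
  using namespace mxnet;
  // Request a plain write for every argument gradient; kNullOp would skip one.
  std::vector<OpReqType> grad_req(in_args.size(), kWriteTo);
  // Bind once: the symbol is combined with the context and the argument NArrays.
  Operator *op = Operator::Bind(net, ctx, in_args, arg_grad_store, grad_req);
  // Forward fills the head NArrays; Backward writes into arg_grad_store.
  op->Forward();
  const std::vector<NArray> &outputs = op->head();
  op->Backward();
  (void)outputs;  // e.g. copy the outputs out or feed them to a loss here
  delete op;
}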