From 3ae48a41c8e156a2889460ffa9b628bef61d48c8 Mon Sep 17 00:00:00 2001 From: winsty Date: Sat, 18 Jul 2015 10:53:08 +0800 Subject: [PATCH 1/7] fc symbol --- include/mxnet/atomic_symbol.h | 4 +- include/mxnet/operator.h | 28 ------- src/static_operator/fully_connect_op-inl.h | 33 ++------ src/symbol/fully_connect_sym-inl.h | 87 ++++++++++++++++++++++ 4 files changed, 97 insertions(+), 55 deletions(-) create mode 100644 src/symbol/fully_connect_sym-inl.h diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h index 8d4019f87d62..2a61f0da88de 100644 --- a/include/mxnet/atomic_symbol.h +++ b/include/mxnet/atomic_symbol.h @@ -14,7 +14,7 @@ #include "./tensor_blob.h" namespace mxnet { -class Operator; +class StaticOperator; /*! * \brief AtomicSymbol is the base class of all atomic symbols. * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance @@ -66,7 +66,7 @@ class AtomicSymbol { * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. * Calling bind from the Symbol wrapper would generate a NArrayOperator. */ - virtual Operator* Bind(Context ctx) const = 0; + virtual StaticOperator* Bind(Context ctx) const = 0; /*! * \brief return the type string of the atomic symbol * subclasses override this function. diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 4dbcea798cdd..2c0245215151 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -33,39 +33,11 @@ class Operator { * \param ctx Context of the Operator */ Operator(StaticOperator* op, Context ctx); - /*! - * \brief get types of input argument of this oeprator - * \return a vector corresponding to type of each argument - * this order is same as the order of inputs in Forward, InferShape and Backward - */ - virtual std::vector DescribeArgs() const; /*! * \brief describe property of op * \return a bit map in int */ virtual int DescribeProperty() const; - /*! - * \brief set param for the operator from string - * \param name parameter name - * \param val string for configuration - */ - virtual void SetParam(const char *name, const char *val); - /*! - * \brief inter the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - */ - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape); - /*! * \brief set the context of the Operator * \param ctx the context to be set to diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index de0250f101bd..d7abd0c5fc10 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -18,6 +18,14 @@ namespace op { template class FullyConnectOp : public StaticOperator { public: + FullyConnectOp () { + // Do nothing. 
+ } + + FullyConnectOp (param p) { + this->param = p; + } + virtual std::vector DescribeArgs() const { ArgType ret[] = {kDataArg, kWeightArg, kBiasArg}; if (param_.no_bias == 0) { @@ -26,30 +34,6 @@ class FullyConnectOp : public StaticOperator { return std::vector(ret, ret + 2); } } - virtual void SetParam(const char *name, const char *val) { - param_.SetParam(name, val); - } - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) { - using namespace mshadow; - if (param_.no_bias == 0) { - CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; - } else { - CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; - } - CHECK_GT(param_.num_hidden, 0); - const TShape &dshape = (*in_shape)[0]; - CHECK_EQ(dshape.ndim(), 4) << \ - "Input data should be 4D in batch-1-1-hidden"; - CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; - ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); - if (param_.no_bias == 0) { - ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); - } - out_shape->clear(); - out_shape->push_back(dshape); - (*out_shape)[0][3] = param_.num_hidden; - } virtual void Forward(Option opt, RunContext ctx, const std::vector &in_data, @@ -102,7 +86,6 @@ class FullyConnectOp : public StaticOperator { Tensor gdata = out_grad[0].FlatTo2D(s); Assign(gdata, req[0], dot(grad, wmat)); } - private: Param param_; }; // class FullyConnectOp diff --git a/src/symbol/fully_connect_sym-inl.h b/src/symbol/fully_connect_sym-inl.h new file mode 100644 index 000000000000..0fcaeb463385 --- /dev/null +++ b/src/symbol/fully_connect_sym-inl.h @@ -0,0 +1,87 @@ + /*! + * Copyright (c) 2015 by Contributors + * \file fully_connect_op-inl.h + * \brief fully connect operator + * \author Bing Xu +*/ +#ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ + +#include +#include +#include +#include "./static_operator_common.h" +#include "./param.h" +#include "../static_operator/fully_connect_op-inl.h" + +namespace mxnet { +template +class FullyConnectSymbol : public AtomicSymbol { + public: + virtual std::vector DescribeArguments() const { + std::string ret[] = {"data", "weight", "bias"}; + if (param_.no_bias == 0) { + return std::vector(ret, ret + 3); + } else { + return std::vector(ret, ret + 2); + } + } + + virtual std::vector DescribeReturns() const { + return std::vector(); + } + + virtual void SetParam(const char *name, const char *val) const { + param_.SetParam(name, val); + } + virtual void InferShape(std::vector *in_shape, + std::vector *out_shape) const { + using namespace mshadow; + if (param_.no_bias == 0) { + CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; + } else { + CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; + } + CHECK_GT(param_.num_hidden, 0); + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4) << \ + "Input data should be 4D in batch-1-1-hidden"; + CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; + ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); + if (param_.no_bias == 0) { + ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); + } + out_shape->clear(); + out_shape->push_back(dshape); + (*out_shape)[0][3] = param_.num_hidden; + } + + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. 
+ * \return a pointer to the copied atomic symbol + */ + virtual AtomicSymbol* Copy() const { + FullyConnectSymbol* fc_sym = new FullyConnectSymbol(); + fc_sym->param = this->param; + return fc_sym; + } + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return Operator, but static operator. + * Calling bind from the Symbol wrapper would generate a Operator. + */ + virtual StaticOperator* Bind(Context ctx) const { + return new FullyConnectOp(param_); + } + + virtual std::string TypeString() const { + return "Fully Connected"; + } + private: + Param param_; +}; // class FullyConnectSymbol +} // namespace mxnet + +#endif // MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ + From 95116e9a08e6ba3b518a38ad8ea56960a5510cd0 Mon Sep 17 00:00:00 2001 From: winsty Date: Mon, 20 Jul 2015 14:52:34 +0800 Subject: [PATCH 2/7] static op wrapper --- Makefile | 2 +- include/mxnet/operator.h | 14 +--- include/mxnet/static_operator.h | 30 ------- include/mxnet/static_operator_wrapper.h | 81 +++++++++++++++++++ ...operator.cc => static_operator_wrapper.cc} | 46 ++--------- src/static_operator/fully_connect_op-inl.h | 2 +- 6 files changed, 92 insertions(+), 83 deletions(-) create mode 100644 include/mxnet/static_operator_wrapper.h rename src/operator/{operator.cc => static_operator_wrapper.cc} (64%) diff --git a/Makefile b/Makefile index 33dc31fdddbc..39e88f0b834b 100644 --- a/Makefile +++ b/Makefile @@ -87,7 +87,7 @@ static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc registry.o: src/registry.cc mxnet_api.o: api/mxnet_api.cc -operator.o: src/operator/operator.cc +operator.o: src/operator/static_operator_wrapper.cc api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 2c0245215151..666a36221a92 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -27,12 +27,6 @@ namespace mxnet { */ class Operator { public: - /*! - * \brief construct Operator from StaticOperator and Context - * \param op StaticOperator to wrap - * \param ctx Context of the Operator - */ - Operator(StaticOperator* op, Context ctx); /*! * \brief describe property of op * \return a bit map in int @@ -54,7 +48,7 @@ class Operator { virtual void Forward(Option opt, RunContext ctx, const std::vector &in_data, - const std::vector &out_data); + const std::vector &out_data) = 0; /*! * \brief perform a backward operation of the operator to get the gradient * \param ctx runtime context @@ -70,11 +64,9 @@ class Operator { const std::vector &grad_next, const std::vector &in_data, const std::vector &out_grad, - const std::vector &req); + const std::vector &req) = 0; - private: - /* \brief the static operator */ - StaticOperator* op; + protected: Context global_ctx; }; } // namespace mxnet diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h index b2f9e49af154..b1ae8fd23de1 100644 --- a/include/mxnet/static_operator.h +++ b/include/mxnet/static_operator.h @@ -23,15 +23,6 @@ namespace mxnet { */ class StaticOperator { public: - /*! - * \brief get types of input argument of this oeprator - * \return a vector corresponding to type of each argument - * this order is same as the order of inputs in Forward, InferShape and Backward - */ - virtual std::vector DescribeArgs() const { - // default most of layers only have one data argument - return std::vector(1, kDataArg); - } /*! 
* \brief describe property of op * \return a bit map in int @@ -40,27 +31,6 @@ class StaticOperator { // default most of layer only conatin internal state return kContainInteralState; } - /*! - * \brief set param for the StaticOperator from string - * \param name parameter name - * \param val string for configuration - */ - virtual void SetParam(const char *name, const char *val) {} - /*! - * \brief inter the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the StaticOperator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the StaticOperator - * InferShape will modify the vector to fill output TShape - */ - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) = 0; /*! * \brief perform a forward operation of StaticOperator, save the output to TBlob * \param opt option on Forward such as whether this is training phase diff --git a/include/mxnet/static_operator_wrapper.h b/include/mxnet/static_operator_wrapper.h new file mode 100644 index 000000000000..6544269a3738 --- /dev/null +++ b/include/mxnet/static_operator_wrapper.h @@ -0,0 +1,81 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_operator_wrapper.h + * \brief operator interface of mxnet + * \author Naiyan Wang + */ +#ifndef MXNET_STATIC_OPERATOR_WRAPPER_H_ +#define MXNET_STATIC_OPERATOR_WRAPPER_H_ +// this file will be seen by cuda, no c++11 for now +#include +#include +#include "./base.h" +#include "./tensor_blob.h" +#include "./static_operator.h" +#include "./operator.h" +#include "./narray.h" +#include "./dag_engine.h" + +namespace mxnet { +/*! + * \brief static operator interface (current interface have not yet todo with scheduler), + * operator is a stateful object that can be used to call forward and backprop + * + * This interface relies on pre-allocated memory in TBlob, the caller need to set + * the memory region in TBlob correctly before calling Forward and Backward + * + * \sa Operator + */ +class StaticOperatorWrapper: public Operator { + public: + /*! + * \brief construct Operator from StaticOperator and Context + * \param op StaticOperator to wrap + * \param ctx Context of the Operator + */ + StaticOperatorWrapper(StaticOperator* op, Context ctx); + /*! + * \brief describe property of op + * \return a bit map in int + */ + virtual int DescribeProperty() const; + /*! + * \brief set the context of the Operator + * \param ctx the context to be set to + */ + virtual void SetContext(Context ctx); + /*! + * \brief perform a forward operation of operator, save the output to TBlob + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of TBlob in out_data must be pre-allocated with InferShape + */ + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data); + /*! 
+ * \brief perform a backward operation of the operator to get the gradient + * \param ctx runtime context + * \param grad_next the gradient value we get from output of the operator + * \param in_data the array of input data + * \param out_grad array of output gradient, there could be three possible TBlob + * in the each element in the array + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType + */ + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req); + + private: + /* \brief the static operator */ + StaticOperator* op; +}; +} // namespace mxnet +#endif // MXNET_STATIC_OPERATOR_WRAPPER_H_ diff --git a/src/operator/operator.cc b/src/operator/static_operator_wrapper.cc similarity index 64% rename from src/operator/operator.cc rename to src/operator/static_operator_wrapper.cc index ccfc640d3a3d..07834739ae0b 100644 --- a/src/operator/operator.cc +++ b/src/operator/static_operator_wrapper.cc @@ -4,56 +4,22 @@ * \brief the implementation of narray operator * \author Naiyan Wang */ -#include +#include namespace mxnet { - Operator::Operator(StaticOperator* op, Context ctx) { + StaticOperatorWrapper::StaticOperatorWrapper(StaticOperator* op, Context ctx) { this->op = op; this->global_ctx = ctx; } - /*! - * \brief get types of input argument of this oeprator - * \return a vector corresponding to type of each argument - * this order is same as the order of inputs in Forward, InferShape and Backward - */ - std::vector Operator::DescribeArgs() const { - // default most of layers only have one data argument - return op->DescribeArgs(); - } /*! * \brief describe property of op * \return a bit map in int */ - int Operator::DescribeProperty() const { + int StaticOperatorWrapper::DescribeProperty() const { // default most of layer only conatin internal state return op->DescribeProperty(); } - /*! - * \brief set param for the operator from string - * \param name parameter name - * \param val string for configuration - */ - void Operator::SetParam(const char *name, const char *val) { - op->SetParam(name, val); - } - /*! - * \brief inter the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - */ - void Operator::InferShape(std::vector *in_shape, - std::vector *out_shape) { - op->InferShape(in_shape, out_shape); - } /*! 
* \brief perform a forward operation of operator, save the output to TBlob * \param opt option on Forward such as whether this is training phase @@ -62,7 +28,7 @@ namespace mxnet { * \param out_data array of output data, * the space of TBlob in out_data must be pre-allocated with InferShape */ - void Operator::Forward(Option opt, + void StaticOperatorWrapper::Forward(Option opt, RunContext ctx, const std::vector &in_data, const std::vector &out_data) { @@ -93,7 +59,7 @@ namespace mxnet { * only inplace will change input data * \sa GradReqType */ - void Operator::Backward(RunContext ctx, + void StaticOperatorWrapper::Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, const std::vector &out_grad, @@ -120,7 +86,7 @@ namespace mxnet { }, global_ctx, used_var, mutate_var); } - void Operator::SetContext(Context ctx) { + void StaticOperatorWrapper::SetContext(Context ctx) { this->global_ctx = ctx; } diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index d7abd0c5fc10..74987255dead 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -22,7 +22,7 @@ class FullyConnectOp : public StaticOperator { // Do nothing. } - FullyConnectOp (param p) { + FullyConnectOp (Param p) { this->param = p; } From b8be6c3ad91389def813a1e59f25f8c2e9eb9dc1 Mon Sep 17 00:00:00 2001 From: winsty Date: Mon, 20 Jul 2015 16:48:35 +0800 Subject: [PATCH 3/7] fix fc sym --- Makefile | 6 ++++-- src/symbol/fully_connect_sym-inl.h | 24 +++++++++++++++--------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 39e88f0b834b..fb7ac440a182 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,7 @@ endif #BIN = test/test_threaded_engine test/api_registry_test BIN = test/api_registry_test -OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o +OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o atomic_symbol_cpu.o # add threaded engine after it is done OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o CUOBJ = @@ -65,7 +65,7 @@ ALIB = api/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o static_operator_gpu.o + CUOBJ += narray_op_gpu.o static_operator_gpu.o atomic_symbol_gpu.o endif .PHONY: clean all test lint doc @@ -88,6 +88,8 @@ symbol.o: src/symbol/symbol.cc registry.o: src/registry.cc mxnet_api.o: api/mxnet_api.cc operator.o: src/operator/static_operator_wrapper.cc +atomic_symbol_cpu.o: src/symbol/fully_connect_sym.cc +atomic_symbol_gpu.o: src/symbol/fully_connect_sym.cu api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/src/symbol/fully_connect_sym-inl.h b/src/symbol/fully_connect_sym-inl.h index 0fcaeb463385..d2443389dead 100644 --- a/src/symbol/fully_connect_sym-inl.h +++ b/src/symbol/fully_connect_sym-inl.h @@ -4,17 +4,19 @@ * \brief fully connect operator * \author Bing Xu */ -#ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ -#define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#ifndef MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ +#define MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ #include #include +#include #include -#include "./static_operator_common.h" -#include "./param.h" #include "../static_operator/fully_connect_op-inl.h" +#include "../static_operator/param.h" namespace mxnet { +using namespace mxnet::op; + template class FullyConnectSymbol : public AtomicSymbol { public: @@ -28,12 +30,16 @@ class 
FullyConnectSymbol : public AtomicSymbol { } virtual std::vector DescribeReturns() const { - return std::vector(); + std::string temp = "output"; + std::vector v; + v.push_back(temp); + return v; } - virtual void SetParam(const char *name, const char *val) const { + virtual void SetParam(const char *name, const char *val) { param_.SetParam(name, val); } + virtual void InferShape(std::vector *in_shape, std::vector *out_shape) const { using namespace mshadow; @@ -71,8 +77,8 @@ class FullyConnectSymbol : public AtomicSymbol { * Bind function of AtomicSymbol does not return Operator, but static operator. * Calling bind from the Symbol wrapper would generate a Operator. */ - virtual StaticOperator* Bind(Context ctx) const { - return new FullyConnectOp(param_); + virtual StaticOperator* Bind_(Context ctx) const { + } virtual std::string TypeString() const { @@ -83,5 +89,5 @@ class FullyConnectSymbol : public AtomicSymbol { }; // class FullyConnectSymbol } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#endif // MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ From 05263cc1ba37741f1828fcd8fdcc165516ef23f4 Mon Sep 17 00:00:00 2001 From: winsty Date: Mon, 20 Jul 2015 17:03:39 +0800 Subject: [PATCH 4/7] bind --- src/symbol/fully_connect_sym-inl.h | 5 +++-- src/symbol/fully_connect_sym.cc | 13 +++++++++++++ src/symbol/fully_connect_sym.cu | 13 +++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 src/symbol/fully_connect_sym.cc create mode 100644 src/symbol/fully_connect_sym.cu diff --git a/src/symbol/fully_connect_sym-inl.h b/src/symbol/fully_connect_sym-inl.h index d2443389dead..e3f60c945448 100644 --- a/src/symbol/fully_connect_sym-inl.h +++ b/src/symbol/fully_connect_sym-inl.h @@ -39,7 +39,7 @@ class FullyConnectSymbol : public AtomicSymbol { virtual void SetParam(const char *name, const char *val) { param_.SetParam(name, val); } - + virtual void InferShape(std::vector *in_shape, std::vector *out_shape) const { using namespace mshadow; @@ -72,13 +72,14 @@ class FullyConnectSymbol : public AtomicSymbol { fc_sym->param = this->param; return fc_sym; } + /*! * \brief Bind this AtomicSymbol to a context and get back a static operator * Bind function of AtomicSymbol does not return Operator, but static operator. * Calling bind from the Symbol wrapper would generate a Operator. */ virtual StaticOperator* Bind_(Context ctx) const { - + return new FullyConnectSymbol(param_); } virtual std::string TypeString() const { diff --git a/src/symbol/fully_connect_sym.cc b/src/symbol/fully_connect_sym.cc new file mode 100644 index 000000000000..445b7357420f --- /dev/null +++ b/src/symbol/fully_connect_sym.cc @@ -0,0 +1,13 @@ + /*! + * Copyright (c) 2015 by Contributors + * \file fully_connect_sym.cc + * \brief fully connect operator symbol +*/ +#include "./fully_connect_sym-inl.h" + +namespace mxnet { + template <> + StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { + return Bind_(ctx); + } +} diff --git a/src/symbol/fully_connect_sym.cu b/src/symbol/fully_connect_sym.cu new file mode 100644 index 000000000000..d94336e376b7 --- /dev/null +++ b/src/symbol/fully_connect_sym.cu @@ -0,0 +1,13 @@ + /*! 
+ * Copyright (c) 2015 by Contributors + * \file fully_connect_sym.cu + * \brief fully connect operator symbol +*/ +#include "./fully_connect_sym-inl.h" + +namespace mxnet { + template <> + StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { + return Bind_(ctx); + } +} From b5c15e26562676ba5c8484787605ab9de36411f8 Mon Sep 17 00:00:00 2001 From: winsty Date: Mon, 20 Jul 2015 17:47:08 +0800 Subject: [PATCH 5/7] fix bind --- include/mxnet/atomic_symbol.h | 5 +- src/static_operator/fully_connect_op-inl.h | 2 +- src/symbol/fully_connect_sym-inl.h | 61 +++++---------------- src/symbol/fully_connect_sym.cc | 63 ++++++++++++++++++++-- src/symbol/fully_connect_sym.cu | 2 +- 5 files changed, 77 insertions(+), 56 deletions(-) diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h index 2a61f0da88de..842fbe6d0198 100644 --- a/include/mxnet/atomic_symbol.h +++ b/include/mxnet/atomic_symbol.h @@ -54,7 +54,7 @@ class AtomicSymbol { * InferShape will modify the vector to fill output TShape * \return if the shape inference is successful, return true, else return false. */ - virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) = 0; + virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; /*! * \brief Copy this AtomicSymbol and returns a pointer to the copied object. * this is a virtual function because different subclass of AtomicSymbol would copy differently. @@ -66,7 +66,8 @@ class AtomicSymbol { * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. * Calling bind from the Symbol wrapper would generate a NArrayOperator. */ - virtual StaticOperator* Bind(Context ctx) const = 0; + template + StaticOperator* Bind(Context ctx) const; /*! * \brief return the type string of the atomic symbol * subclasses override this function. 
diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index 74987255dead..b3a00151574f 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -23,7 +23,7 @@ class FullyConnectOp : public StaticOperator { } FullyConnectOp (Param p) { - this->param = p; + this->param_ = p; } virtual std::vector DescribeArgs() const { diff --git a/src/symbol/fully_connect_sym-inl.h b/src/symbol/fully_connect_sym-inl.h index e3f60c945448..65b719c06a68 100644 --- a/src/symbol/fully_connect_sym-inl.h +++ b/src/symbol/fully_connect_sym-inl.h @@ -17,74 +17,37 @@ namespace mxnet { using namespace mxnet::op; -template class FullyConnectSymbol : public AtomicSymbol { public: - virtual std::vector DescribeArguments() const { - std::string ret[] = {"data", "weight", "bias"}; - if (param_.no_bias == 0) { - return std::vector(ret, ret + 3); - } else { - return std::vector(ret, ret + 2); - } - } + virtual std::vector DescribeArguments() const; - virtual std::vector DescribeReturns() const { - std::string temp = "output"; - std::vector v; - v.push_back(temp); - return v; - } + virtual std::vector DescribeReturns() const; - virtual void SetParam(const char *name, const char *val) { - param_.SetParam(name, val); - } + virtual void SetParam(const char *name, const char *val); - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) const { - using namespace mshadow; - if (param_.no_bias == 0) { - CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; - } else { - CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; - } - CHECK_GT(param_.num_hidden, 0); - const TShape &dshape = (*in_shape)[0]; - CHECK_EQ(dshape.ndim(), 4) << \ - "Input data should be 4D in batch-1-1-hidden"; - CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; - ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); - if (param_.no_bias == 0) { - ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); - } - out_shape->clear(); - out_shape->push_back(dshape); - (*out_shape)[0][3] = param_.num_hidden; - } + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const; /*! * \brief Copy this AtomicSymbol and returns a pointer to the copied object. * this is a virtual function because different subclass of AtomicSymbol would copy differently. * \return a pointer to the copied atomic symbol */ - virtual AtomicSymbol* Copy() const { - FullyConnectSymbol* fc_sym = new FullyConnectSymbol(); - fc_sym->param = this->param; - return fc_sym; - } + virtual AtomicSymbol* Copy() const; + template + StaticOperator* Bind(Context ctx) const; /*! * \brief Bind this AtomicSymbol to a context and get back a static operator * Bind function of AtomicSymbol does not return Operator, but static operator. * Calling bind from the Symbol wrapper would generate a Operator. 
*/ - virtual StaticOperator* Bind_(Context ctx) const { - return new FullyConnectSymbol(param_); + template + StaticOperator* Bind_(Context ctx) const { + return new FullyConnectOp(param_); } - virtual std::string TypeString() const { - return "Fully Connected"; - } + virtual std::string TypeString() const; private: Param param_; }; // class FullyConnectSymbol diff --git a/src/symbol/fully_connect_sym.cc b/src/symbol/fully_connect_sym.cc index 445b7357420f..70c7b54505fa 100644 --- a/src/symbol/fully_connect_sym.cc +++ b/src/symbol/fully_connect_sym.cc @@ -6,8 +6,65 @@ #include "./fully_connect_sym-inl.h" namespace mxnet { - template <> - StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - return Bind_(ctx); + std::vector FullyConnectSymbol::DescribeArguments() const { + std::string ret[] = {"data", "weight", "bias"}; + if (param_.no_bias == 0) { + return std::vector(ret, ret + 3); + } else { + return std::vector(ret, ret + 2); + } + } + + std::vector FullyConnectSymbol::DescribeReturns() const { + std::string temp = "output"; + std::vector v; + v.push_back(temp); + return v; + } + + void FullyConnectSymbol::SetParam(const char *name, const char *val) { + param_.SetParam(name, val); + } + + bool FullyConnectSymbol::InferShape(std::vector *in_shape, + std::vector *out_shape) const { + using namespace mshadow; + if (param_.no_bias == 0) { + CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; + } else { + CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; + } + CHECK_GT(param_.num_hidden, 0); + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4) << \ + "Input data should be 4D in batch-1-1-hidden"; + CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; + ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); + if (param_.no_bias == 0) { + ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); + } + out_shape->clear(); + out_shape->push_back(dshape); + (*out_shape)[0][3] = param_.num_hidden; + return true; + } + + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a function because different subclass of AtomicSymbol would copy differently. 
+ * \return a pointer to the copied atomic symbol + */ + AtomicSymbol* FullyConnectSymbol::Copy() const { + FullyConnectSymbol* fc_sym = new FullyConnectSymbol(); + fc_sym->param_ = this->param_; + return fc_sym; + } + std::string FullyConnectSymbol::TypeString() const { + return "Fully Connected"; + } + + template<> + StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { + return Bind_(ctx); } } diff --git a/src/symbol/fully_connect_sym.cu b/src/symbol/fully_connect_sym.cu index d94336e376b7..47df3f28d79e 100644 --- a/src/symbol/fully_connect_sym.cu +++ b/src/symbol/fully_connect_sym.cu @@ -7,7 +7,7 @@ namespace mxnet { template <> - StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { + StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { return Bind_(ctx); } } From 118acd2ccca5e7f6c44d9eaa8f5c9f1a65013d53 Mon Sep 17 00:00:00 2001 From: winsty Date: Tue, 21 Jul 2015 15:16:05 +0800 Subject: [PATCH 6/7] fix --- src/static_operator/fully_connect_op-inl.h | 35 +++++++++++++ src/symbol/fully_connect_sym-inl.h | 57 ---------------------- src/symbol/fully_connect_sym.cc | 24 +++++++-- src/symbol/fully_connect_sym.cu | 12 +++-- 4 files changed, 61 insertions(+), 67 deletions(-) delete mode 100644 src/symbol/fully_connect_sym-inl.h diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index b3a00151574f..e725109abd18 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -9,12 +9,14 @@ #include #include +#include #include #include "./static_operator_common.h" #include "./param.h" namespace mxnet { namespace op { + template class FullyConnectOp : public StaticOperator { public: @@ -89,6 +91,39 @@ class FullyConnectOp : public StaticOperator { private: Param param_; }; // class FullyConnectOp + +class FullyConnectSymbol : public AtomicSymbol { + public: + virtual std::vector DescribeArguments() const; + + virtual std::vector DescribeReturns() const; + + virtual void SetParam(const char *name, const char *val); + + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const; + + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. + * \return a pointer to the copied atomic symbol + */ + virtual AtomicSymbol* Copy() const; + + StaticOperator* Bind(Context ctx) const; + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return Operator, but static operator. + * Calling bind from the Symbol wrapper would generate a Operator. + */ + template + StaticOperator* Bind_(Context ctx) const; + + virtual std::string TypeString() const; + private: + Param param_; +}; // class FullyConnectSymbol + } // namespace op } // namespace mxnet diff --git a/src/symbol/fully_connect_sym-inl.h b/src/symbol/fully_connect_sym-inl.h deleted file mode 100644 index 65b719c06a68..000000000000 --- a/src/symbol/fully_connect_sym-inl.h +++ /dev/null @@ -1,57 +0,0 @@ - /*! 
- * Copyright (c) 2015 by Contributors - * \file fully_connect_op-inl.h - * \brief fully connect operator - * \author Bing Xu -*/ -#ifndef MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ -#define MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ - -#include -#include -#include -#include -#include "../static_operator/fully_connect_op-inl.h" -#include "../static_operator/param.h" - -namespace mxnet { -using namespace mxnet::op; - -class FullyConnectSymbol : public AtomicSymbol { - public: - virtual std::vector DescribeArguments() const; - - virtual std::vector DescribeReturns() const; - - virtual void SetParam(const char *name, const char *val); - - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const; - - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ - virtual AtomicSymbol* Copy() const; - - template - StaticOperator* Bind(Context ctx) const; - /*! - * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return Operator, but static operator. - * Calling bind from the Symbol wrapper would generate a Operator. - */ - template - StaticOperator* Bind_(Context ctx) const { - return new FullyConnectOp(param_); - } - - virtual std::string TypeString() const; - private: - Param param_; -}; // class FullyConnectSymbol -} // namespace mxnet - -#endif // MXNET_SYMBOL_FULLY_CONNECT_SYM_INL_H_ - diff --git a/src/symbol/fully_connect_sym.cc b/src/symbol/fully_connect_sym.cc index 70c7b54505fa..54cbb1cc776d 100644 --- a/src/symbol/fully_connect_sym.cc +++ b/src/symbol/fully_connect_sym.cc @@ -3,9 +3,10 @@ * \file fully_connect_sym.cc * \brief fully connect operator symbol */ -#include "./fully_connect_sym-inl.h" +#include "../static_operator/fully_connect_op-inl.h" namespace mxnet { +namespace op { std::vector FullyConnectSymbol::DescribeArguments() const { std::string ret[] = {"data", "weight", "bias"}; if (param_.no_bias == 0) { @@ -60,11 +61,24 @@ namespace mxnet { return fc_sym; } std::string FullyConnectSymbol::TypeString() const { - return "Fully Connected"; + return "fully_connected"; } template<> - StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - return Bind_(ctx); + StaticOperator* FullyConnectSymbol::Bind_(Context ctx) const { + return new FullyConnectOp(param_); } -} + + StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { + if (ctx.dev_mask == cpu::kDevMask) { + return Bind_(ctx); + } else { +#if MXNET_USE_CUDA + return Bind_(ctx); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } + } +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/symbol/fully_connect_sym.cu b/src/symbol/fully_connect_sym.cu index 47df3f28d79e..6fcabf019d1a 100644 --- a/src/symbol/fully_connect_sym.cu +++ b/src/symbol/fully_connect_sym.cu @@ -3,11 +3,13 @@ * \file fully_connect_sym.cu * \brief fully connect operator symbol */ -#include "./fully_connect_sym-inl.h" +#include "../static_operator/fully_connect_op-inl.h" namespace mxnet { - template <> - StaticOperator* FullyConnectSymbol::Bind(Context ctx) const { - return Bind_(ctx); +namespace op { + template<> + StaticOperator* FullyConnectSymbol::Bind_(Context ctx) const { + return new FullyConnectOp(param_); } -} +} // namespace op +} // namespace mxnet \ No newline at end of file From 4722b73fac74b331ad7f6d40103e8d171dc6f732 Mon Sep 17 00:00:00 2001 
From: winsty Date: Mon, 27 Jul 2015 20:35:39 +0800 Subject: [PATCH 7/7] fix style and doc --- include/mxnet/narray.h | 2 +- include/mxnet/operator.h | 32 ++++++----- include/mxnet/static_operator.h | 2 +- include/mxnet/symbol.h | 8 +-- src/operator/static_operator_wrapper.cc | 12 ++--- .../operator}/static_operator_wrapper.h | 18 +++---- src/static_operator/fully_connect_op-inl.h | 53 +++++++++++-------- src/static_operator/static_operator.cc | 12 ++--- src/symbol/fully_connect_sym.cc | 11 ++-- 9 files changed, 79 insertions(+), 71 deletions(-) rename {include/mxnet => src/operator}/static_operator_wrapper.h (88%) diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 458bf9f6c834..4e7b4448e667 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -73,7 +73,7 @@ class NArray { DAGEngine::Get()->WaitForVar(ptr_->var); } /*! \return the associated DAG variable of the narray.*/ - inline DAGEngine::Variable Var() const { + inline DAGEngine::Variable var() const { return ptr_->var; } /*! diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 666a36221a92..ea1e44990e9d 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -17,12 +17,11 @@ namespace mxnet { /*! - * \brief static operator interface (current interface have not yet todo with scheduler), - * operator is a stateful object that can be used to call forward and backprop - * - * This interface relies on pre-allocated memory in TBlob, the caller need to set - * the memory region in TBlob correctly before calling Forward and Backward + * \brief operator interface + * operator is an object can be scheduled by DAG engine directly. * + * This interface relies on NArray. The user should prepare the input NArray and + * output NArray by themselves. * \sa Operator */ class Operator { @@ -33,17 +32,20 @@ class Operator { */ virtual int DescribeProperty() const; /*! - * \brief set the context of the Operator + * \brief set the global context of the Operator * \param ctx the context to be set to */ virtual void SetContext(Context ctx); /*! - * \brief perform a forward operation of operator, save the output to TBlob + * \brief perform a forward operation of operator, save the output to NArray + * This method only pushes an execution request to the DAG engine, and + * return immediately. Actual execution is conducted by the DAG engine. * \param opt option on Forward such as whether this is training phase * \param ctx runtime context * \param in_data array of input data, it is const * \param out_data array of output data, - * the space of TBlob in out_data must be pre-allocated with InferShape + * the space of NArray in out_data must be pre-allocated with InferShape + * \sa NArray */ virtual void Forward(Option opt, RunContext ctx, @@ -51,14 +53,15 @@ class Operator { const std::vector &out_data) = 0; /*! * \brief perform a backward operation of the operator to get the gradient + * This method only pushes an execution request to the DAG engine, and + * return immediately. Actual execution is conducted by the DAG engine. * \param ctx runtime context - * \param grad_next the gradient value we get from output of the operator + * \param grad_next the gradient value of the output of the operator, used by chain rule. 
* \param in_data the array of input data - * \param out_grad array of output gradient, there could be three possible TBlob - * in the each element in the array + * \param out_grad array of output gradient * \param req request types of the gradient saving operation * only inplace will change input data - * \sa GradReqType + * \sa GradReqType, NArray */ virtual void Backward(RunContext ctx, const std::vector &grad_next, @@ -67,7 +70,10 @@ class Operator { const std::vector &req) = 0; protected: + /** + * \brief the global context denots the device info. + */ Context global_ctx; -}; +}; // class operator } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h index b1ae8fd23de1..2f989bf5aca8 100644 --- a/include/mxnet/static_operator.h +++ b/include/mxnet/static_operator.h @@ -60,7 +60,7 @@ class StaticOperator { const std::vector &out_grad, const std::vector &req) = 0; /*! - * \brief factory unction, create a new StaticOperator + * \brief factory function, create a new StaticOperator * \param type the type of StaticOperator * \param ctx the context device type of StaticOperator * \return a pointer of StaticOperator object diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 6f0b5ef8af93..0b69005f7a16 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -15,9 +15,9 @@ #include #include "./base.h" #include "./tensor_blob.h" +#include "./operator.h" namespace mxnet { -class NArrayOperator; /*! * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol * should support expressions and often passed by value. While AtomicSymbol have many subclasses, @@ -67,11 +67,11 @@ class Symbol { */ virtual ~Symbol() {} /*! - * \brief bind to device and returns an NArrayOperator. + * \brief bind to device and returns an operator. * \param ctx context of the operator - * \return returns the pointer to a created NArrayOperator. It is on the user to delete. + * \return returns the pointer to a created operator. It is on the user to delete. */ - virtual NArrayOperator* Bind(Context ctx) const { return nullptr; } + virtual Operator* Bind(Context ctx) const { return nullptr; } /*! 
* \brief copy the symbol * \return a deep copy of the graph diff --git a/src/operator/static_operator_wrapper.cc b/src/operator/static_operator_wrapper.cc index 07834739ae0b..7304e071461d 100644 --- a/src/operator/static_operator_wrapper.cc +++ b/src/operator/static_operator_wrapper.cc @@ -4,7 +4,7 @@ * \brief the implementation of narray operator * \author Naiyan Wang */ -#include +#include "./static_operator_wrapper.h" namespace mxnet { @@ -37,11 +37,11 @@ namespace mxnet { std::vector in; std::vector out; for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].Var()); + used_var.push_back(in_data[i].var()); in.push_back(in_data[i].data()); } for (size_t i = 0; i < out_data.size(); ++i) { - mutate_var.push_back(out_data[i].Var()); + mutate_var.push_back(out_data[i].var()); out.push_back(out_data[i].data()); } DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) { @@ -70,15 +70,15 @@ namespace mxnet { std::vector grad_out; std::vector data; for (size_t i = 0; i < grad_next.size(); ++i) { - used_var.push_back(grad_next[i].Var()); + used_var.push_back(grad_next[i].var()); grad_in.push_back(grad_next[i].data()); } for (size_t i = 0; i < in_data.size(); ++i) { - used_var.push_back(in_data[i].Var()); + used_var.push_back(in_data[i].var()); data.push_back(in_data[i].data()); } for (size_t i = 0; i < out_grad.size(); ++i) { - mutate_var.push_back(out_grad[i].Var()); + mutate_var.push_back(out_grad[i].var()); grad_out.push_back(out_grad[i].data()); } DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { diff --git a/include/mxnet/static_operator_wrapper.h b/src/operator/static_operator_wrapper.h similarity index 88% rename from include/mxnet/static_operator_wrapper.h rename to src/operator/static_operator_wrapper.h index 6544269a3738..6897ff9751f6 100644 --- a/include/mxnet/static_operator_wrapper.h +++ b/src/operator/static_operator_wrapper.h @@ -4,17 +4,17 @@ * \brief operator interface of mxnet * \author Naiyan Wang */ -#ifndef MXNET_STATIC_OPERATOR_WRAPPER_H_ -#define MXNET_STATIC_OPERATOR_WRAPPER_H_ +#ifndef MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_ +#define MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_ // this file will be seen by cuda, no c++11 for now #include +#include +#include +#include +#include +#include +#include #include -#include "./base.h" -#include "./tensor_blob.h" -#include "./static_operator.h" -#include "./operator.h" -#include "./narray.h" -#include "./dag_engine.h" namespace mxnet { /*! @@ -78,4 +78,4 @@ class StaticOperatorWrapper: public Operator { StaticOperator* op; }; } // namespace mxnet -#endif // MXNET_STATIC_OPERATOR_WRAPPER_H_ +#endif // MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_ diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index e725109abd18..7cbf05653903 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -1,8 +1,7 @@ /*! * Copyright (c) 2015 by Contributors * \file fully_connect_op-inl.h - * \brief fully connect operator - * \author Bing Xu + * \brief fully connect operator and symbol */ #ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ #define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_ @@ -11,31 +10,34 @@ #include #include #include +#include #include "./static_operator_common.h" #include "./param.h" namespace mxnet { namespace op { - +/** + * \brief This is the implementation of fully connected layer. + * + * \tparam xpu The device that the op will be executed on. 
+ */ template class FullyConnectOp : public StaticOperator { public: - FullyConnectOp () { + /*! + * \brief default constructor + */ + FullyConnectOp() { // Do nothing. } - FullyConnectOp (Param p) { + /*! + * \brief constructor with parameters. Used in Bind() in corresponding symbol. + */ + explicit FullyConnectOp(Param p) { this->param_ = p; } - virtual std::vector DescribeArgs() const { - ArgType ret[] = {kDataArg, kWeightArg, kBiasArg}; - if (param_.no_bias == 0) { - return std::vector(ret, ret + 3); - } else { - return std::vector(ret, ret + 2); - } - } virtual void Forward(Option opt, RunContext ctx, const std::vector &in_data, @@ -57,6 +59,7 @@ class FullyConnectOp : public StaticOperator { out += repmat(bias, data.size(0)); } } + virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, @@ -88,10 +91,15 @@ class FullyConnectOp : public StaticOperator { Tensor gdata = out_grad[0].FlatTo2D(s); Assign(gdata, req[0], dot(grad, wmat)); } + private: + /** The param of the fully connected layer.*/ Param param_; }; // class FullyConnectOp +/** + * @brief The symbol part of the fully connected layer. + */ class FullyConnectSymbol : public AtomicSymbol { public: virtual std::vector DescribeArguments() const; @@ -103,24 +111,23 @@ class FullyConnectSymbol : public AtomicSymbol { virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const; - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ virtual AtomicSymbol* Copy() const; StaticOperator* Bind(Context ctx) const; - /*! - * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return Operator, but static operator. - * Calling bind from the Symbol wrapper would generate a Operator. + + virtual std::string TypeString() const; + + /** + * @brief This is the template function of bind() implementation. + * + * @param ctx The device context + * @return A device dependent static operator can be used for execution. 
*/ template StaticOperator* Bind_(Context ctx) const; - virtual std::string TypeString() const; private: + /** The param of the fully connected layer.*/ Param param_; }; // class FullyConnectSymbol diff --git a/src/static_operator/static_operator.cc b/src/static_operator/static_operator.cc index 67464fb394b6..4a2a121532dd 100644 --- a/src/static_operator/static_operator.cc +++ b/src/static_operator/static_operator.cc @@ -12,11 +12,12 @@ namespace mxnet { namespace op { -// declare the operator -template -StaticOperator *CreateOperator(OpType type); - - +/** + * @brief return a OpType based on string description + * + * @param type the string description of operators + * @return the OpType indicated the type of operators + */ OpType GetOpType(const char *type) { if (!strcmp(type, "relu")) return kReLU; if (!strcmp(type, "fullc")) return kFullc; @@ -25,7 +26,6 @@ OpType GetOpType(const char *type) { } } // namespace op -// implementing the context StaticOperator *StaticOperator::Create(const char *type, Context ctx) { op::OpType otype = op::GetOpType(type); diff --git a/src/symbol/fully_connect_sym.cc b/src/symbol/fully_connect_sym.cc index 54cbb1cc776d..6ef21437b43e 100644 --- a/src/symbol/fully_connect_sym.cc +++ b/src/symbol/fully_connect_sym.cc @@ -23,7 +23,7 @@ namespace op { return v; } - void FullyConnectSymbol::SetParam(const char *name, const char *val) { + void FullyConnectSymbol::SetParam(const char *name, const char *val) { param_.SetParam(name, val); } @@ -50,11 +50,6 @@ namespace op { return true; } - /*! - * \brief Copy this AtomicSymbol and returns a pointer to the copied object. - * this is a function because different subclass of AtomicSymbol would copy differently. - * \return a pointer to the copied atomic symbol - */ AtomicSymbol* FullyConnectSymbol::Copy() const { FullyConnectSymbol* fc_sym = new FullyConnectSymbol(); fc_sym->param_ = this->param_; @@ -80,5 +75,5 @@ namespace op { #endif } } -} // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace op +} // namespace mxnet