diff --git a/Makefile b/Makefile index ff8cd7fd4d02..b159e0bc9429 100644 --- a/Makefile +++ b/Makefile @@ -40,15 +40,16 @@ endif ifneq ($(ADD_CFLAGS), NONE) CFLAGS += $(ADD_CFLAGS) + CFLAGS += -DDMLC_USE_CXX11=1 endif ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif -OBJ = storage.o narray_op_cpu.o -OBJCXX11 = engine.o narray.o -CUOBJ = narray_op_gpu.o +OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o +OBJCXX11 = engine.o narray.o +CUOBJ = narray_op_gpu.o operator_gpu.o LIB_DEP = $(DMLC_CORE)/libdmlc.a @@ -64,6 +65,9 @@ engine.o: src/dag_engine/simple_engine.cc narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h +operator.o: src/operator/operator.cc +operator_cpu.o: src/operator/operator_cpu.cc +operator_gpu.o: src/operator/operator_gpu.cu $(BIN) : $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) @@ -72,7 +76,7 @@ $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(OBJCXX11) : - $(CXX) -std=c++0x -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) + $(CXX) -std=c++11 -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(SLIB) : $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 8e3a398fd9ad..287d9761a736 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -7,6 +7,7 @@ #define MXNET_NARRAY_H_ #include #include +#include #include "./base.h" #include "./storage.h" #include "./tensor_blob.h" @@ -25,7 +26,7 @@ class NArray { /*! \brief default cosntructor */ NArray() {} /*! - * \brief constructing a new dynamic NArray + * \brief constructing a new dynamic NArray * \param shape the shape of array * \param ctx context of NArray */ @@ -34,16 +35,16 @@ class NArray { } /*! 
* \brief constructing a static NArray that shares data with TBlob - * Use with caution: allocate ONLY ONE NArray for each TBlob, + * Use with caution: allocate ONLY ONE NArray for each TBlob, * make sure the memory region is available through out the life of NArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at - */ + */ NArray(const TBlob &data, int dev_id) : ptr_(new Chunk(data, dev_id)) { } /*! - * \return the shape of current NArray + * \return the shape of current NArray */ inline const TShape &shape() const { return ptr_->data.shape_; @@ -57,7 +58,7 @@ class NArray { /*! \return whether this narray is not initialized */ inline bool is_empty() const { return ptr_.get() == nullptr; - } + } private: /*! \brief the real data chunk that backs NArray */ @@ -79,7 +80,7 @@ class NArray { Chunk() : static_data(true), delay_alloc(false) { var = DAGEngine::Get()->NewVar(); } - /*! \brief construct from static data */ + /*! \brief construct from static data */ Chunk(const TBlob &data, int dev_id) : data(data), static_data(true), @@ -118,14 +119,14 @@ class NArray { /*! \brief internal data of NArray */ std::shared_ptr ptr_; /*! - * \brief constructing a new dynamic NArray + * \brief constructing a new dynamic NArray * \param shape the shape of array * \param ctx context of NArray * \param delay_alloc whether delay the allocation */ NArray(const TShape &shape, Context ctx, bool delay_alloc) : ptr_(new Chunk(shape, ctx, delay_alloc)) { - } + } // add friend to helper functions template friend NArray BinaryEWise(const NArray &lhs, const NArray &rhs); diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index fbe2e2c8f6af..a9b3c9f2b3ae 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -1,13 +1,13 @@ /*! 
* Copyright (c) 2015 by Contributors * \file operator.h - * \brief operator interface of mxnet + * \brief static operator interface of mxnet */ #ifndef MXNET_OPERATOR_H_ #define MXNET_OPERATOR_H_ +// this file will be seen by cuda, no c++11 for now #include #include "./base.h" -#include "./narray.h" #include "./tensor_blob.h" namespace mxnet { @@ -38,24 +38,64 @@ class Operator { /*! \brief add to the provided space */ kAddTo = 3 }; + /*! \brief input argument types an operator can have */ + enum ArgType { + /*! \brief data argument */ + kDataArg = 0, + /*! \brief weight argument */ + kWeightArg = 1, + /*! \brief bias argument */ + kBiasArg = 2 + }; + enum Property { + /*! \brief Op contains internal state, won't influence engine schedule */ + kContainInteralState = 1, + /*! \brief Op forward requires random number, will influence engine schedule */ + kForwardRequireRnd = 2, + }; + /*! + * \brief get types of input arguments of this operator + * \return a vector corresponding to type of each argument + * this order is same as the order of inputs in Forward, InferShape and Backward + */ + virtual std::vector DescribeArgs() const { + // default: most layers only have one data argument + return std::vector(1, kDataArg); + } + /*! + * \brief describe property of op + * \return a bit map in int + */ + virtual int DescribeProperty() const { + // default: most layers only contain internal state + return kContainInteralState; + } /*! * \brief set param for the operator from string * \param name parameter name * \param val string for configuration */ - virtual void SetParam(const char *name, const char *val) {} + virtual void SetParam(const char *name, const char *val) {} /*!
- * \brief inter the shape of output given the input data + * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be inferred + * * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape */ - virtual void InferShape(const std::vector &in_shape, + virtual void InferShape(std::vector *in_shape, std::vector *out_shape) = 0; /*! * \brief perform a forward operation of operator, save the output to TBlob * \param opt option on Forward such as whether this is training phase * \param ctx runtime context - * \param in_data array of input data + * \param in_data array of input data, it is const * \param out_data array of output data, * the space of TBlob in out_data must be pre-allocated with InferShape */ @@ -71,13 +111,21 @@ class Operator { * \param out_grad array of output gradient, there could be three possible TBlob * in the each element in the array * \param req_types request types of the gradient saving operation + * only inplace will change input data * \sa GradReqType */ virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, const std::vector &out_grad, - const std::vector req); + const std::vector &req); + + /*!
+ * \brief factory function, create a new operator + * \param type the type of operator + * \param ctx the context device type of operator + */ + static Operator *Create(const char *type, Context ctx); }; } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/src/dag_engine/simple_engine.cc b/src/dag_engine/simple_engine.cc index 2e35b2ff57fc..9ea42e979735 100644 --- a/src/dag_engine/simple_engine.cc +++ b/src/dag_engine/simple_engine.cc @@ -1,19 +1,18 @@ #include #include - namespace mxnet { class SimpleEngine : public DAGEngine { public: virtual void Push(AsyncOp exec_fun, Context exec_ctx, - const std::vector &use_vars, + const std::vector &use_vars, const std::vector &mutate_vars) { // cannot schedule async using naive way because deps are not captured LOG(FATAL) << "cannot schedule async operations"; } virtual void Push(Op exec_fun, Context exec_ctx, - const std::vector &use_vars, + const std::vector &use_vars, const std::vector &mutate_vars) { exec_fun(RunContext()); } @@ -25,7 +24,7 @@ class SimpleEngine : public DAGEngine { // that have the info about the variable // use ptr directly instead of ID because this avoids an indirect mapping return NULL; - } + } }; // implements the singleton factory DAGEngine* DAGEngine::Get() { diff --git a/src/narray/narray_op-inl.h b/src/narray/narray_op-inl.h index 918149ff298b..9891d9a993d0 100644 --- a/src/narray/narray_op-inl.h +++ b/src/narray/narray_op-inl.h @@ -19,7 +19,7 @@ namespace mxnet { namespace narray { // true implementation template -inline void Eval_(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { +inline void Eval_(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = static_cast*>(ctx.stream); ret.FlatTo2D(s) diff --git a/src/operator/activation_op-inl.h b/src/operator/activation_op-inl.h new file mode 100644 index 000000000000..2a412ef3b2e1 --- /dev/null +++ b/src/operator/activation_op-inl.h @@ -0,0 +1,60 @@ +/*!
+ * Copyright (c) 2015 by Contributors + * \file activation_op-inl.h + * \brief activation operator of mxnet: exactly one input and one output, the output has the same shape as the input + */ + +#ifndef MXNET_OPERATOR_ACTIVATION_OP_INL_H_ +#define MXNET_OPERATOR_ACTIVATION_OP_INL_H_ + +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { +template +class ActivationOp : public Operator { + public: + virtual void InferShape(std::vector *in_shape, + std::vector *out_shape) { + CHECK(in_shape->size() == 1) << "Only 1 input is allowed"; + CHECK((*in_shape)[0].ndim() != 0 ) << "Require data shape to be known"; + out_shape->clear(); + out_shape->push_back((*in_shape)[0]); + } + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data) { + CHECK(out_data.size() == 1); + CHECK(in_data.size() == 1); + mshadow::Stream *stream = \ + static_cast *>(ctx.stream); + mshadow::Tensor in = in_data[0].FlatTo2D(stream); + mshadow::Tensor out = out_data[0].FlatTo2D(stream); + out = mshadow::expr::F(in); + } + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) { + CHECK(grad_next.size() == 1); + CHECK(in_data.size() == 1); + CHECK(out_grad.size() == 1); + CHECK(req.size() == 1); + mshadow::Stream *stream = \ + static_cast *>(ctx.stream); + mshadow::Tensor grad = grad_next[0].FlatTo2D(stream); + mshadow::Tensor data = in_data[0].FlatTo2D(stream); + mshadow::Tensor out = out_grad[0].FlatTo2D(stream); + Assign(out, req[0], mshadow::expr::F( + mshadow::expr::F(data)) * grad); + } +}; // class ActivationOp +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_ACTIVATION_OP_INL_H_ diff --git a/src/operator/fully_connect_op-inl.h b/src/operator/fully_connect_op-inl.h new file mode 100644 index 000000000000..a7f07601b374 --- /dev/null +++ b/src/operator/fully_connect_op-inl.h @@ -0,0 +1,108 @@ +/*!
+ * Copyright (c) 2015 by Contributors + * \file fully_connect_op-inl.h + * \brief fully connected operator: out = dot(data, weight.T()), plus bias unless no_bias is set + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ +#define MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ + +#include +#include +#include +#include "./operator_common.h" +#include "./param.h" + +namespace mxnet { +namespace op { +template +class FullyConnectOp : public Operator { + public: + virtual std::vector DescribeArgs() const { + ArgType ret[] = {kDataArg, kWeightArg, kBiasArg}; + if (param_.no_bias == 0) { + return std::vector(ret, ret + 3); + } else { + return std::vector(ret, ret + 2); + } + } + virtual void SetParam(const char *name, const char *val) { + param_.SetParam(name, val); + } + virtual void InferShape(std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; + if (param_.no_bias == 0) { + CHECK(in_shape->size() == 3) << "Input:[data, weight, bias]"; + } else { + CHECK(in_shape->size() == 2) << "Input:[data, weight]"; + } + CHECK(param_.num_hidden > 0); + const TShape &dshape = (*in_shape)[0]; + CHECK(dshape.ndim() != 0) << "Require data shape to be known"; // checked first so an unknown shape reports this message + CHECK(dshape.ndim() == 4) << \ + "Input data should be 4D in batch-1-1-hidden"; + ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3])); + if (param_.no_bias == 0) { + ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden)); + } + out_shape->clear(); + out_shape->push_back(dshape); + (*out_shape)[0][3] = param_.num_hidden; + } + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data) { + using namespace mshadow; + using namespace mshadow::expr; + size_t expected = param_.no_bias == 0 ?
3 : 2; + CHECK(in_data.size() == expected); + CHECK(out_data.size() == 1); + Stream *s = static_cast *>(ctx.stream); + Tensor data = in_data[0].FlatTo2D(s); + Tensor wmat = in_data[1].get(s); + Tensor out = out_data[0].FlatTo2D(s); + out = dot(data, wmat.T()); + if (param_.no_bias == 0) { + Tensor bias = in_data[2].get(s); + out += repmat(bias, data.size(0)); + } + } + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK(grad_next.size() == 1); + size_t expected = param_.no_bias == 0 ? 3 : 2; + CHECK(in_data.size() == expected && out_grad.size() == expected); + CHECK(req.size() == expected); // one request per input argument: 3 with bias, 2 without + Stream *s = static_cast *>(ctx.stream); + Tensor data = in_data[0].FlatTo2D(s); + Tensor wmat = in_data[1].get(s); + Tensor grad = grad_next[0].FlatTo2D(s); + // backprop + CHECK(req[1] != kWriteInplace) << "cannot write weight inplace"; + // gradient of weight + Tensor gwmat = out_grad[1].get(s); + Assign(gwmat, req[1], dot(grad.T(), data)); + // gradient of bias + if (param_.no_bias == 0) { + Tensor gbias = out_grad[2].get(s); + Assign(gbias, req[2], sum_rows(grad)); + } + // gradient of data + Tensor gdata = out_grad[0].FlatTo2D(s); + Assign(gdata, req[0], dot(grad, wmat)); + } + private: + Param param_; +}; // class FullyConnectOp +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_FULLY_CONNECT_OP_INL_H_ + diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h new file mode 100644 index 000000000000..7c2f0c7b6a76 --- /dev/null +++ b/src/operator/mshadow_op.h @@ -0,0 +1,106 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file mshadow_op.h + * \brief extra mshadow operation for mxnet + * \author Bing Xu + */ +#ifndef MXNET_MSHADOW_OPERATOR_OP_H_ +#define MXNET_MSHADOW_OPERATOR_OP_H_ +#include +#include + +namespace mxnet { +/*!
\brief operations for ActivationLayer */ +namespace op { +struct identity { + MSHADOW_XINLINE static real_t Map(real_t a) { + return a; + } +}; +struct identity_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return 1.0f; + } +}; + +/*! \brief sigmoid unit: 1 / (1 + exp(-a)) */ +struct sigmoid { + MSHADOW_XINLINE static real_t Map(real_t a) { + return 1.0f / (1.0f + expf(-a)); + } +}; +struct sigmoid_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return a * (1.0f - a); + } +}; +/*! \brief Rectified Linear Operation: a if a > 0, otherwise 0 */ +struct relu { + MSHADOW_XINLINE static real_t Map(real_t a) { + return a > 0.0f ? a : 0.0f; + } +}; +struct relu_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return a > 0.0f ? 1.0f : 0.0f; + } +}; + +/*! \brief Leaky ReLU Operation: a if a > 0, otherwise a / b */ +struct xelu { + MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { + return a > 0 ? a : a / b; + } +}; + +struct xelu_grad { + MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { + return a > 0 ? 1 : 1.0f / b; + } +}; + +struct tanh { + MSHADOW_XINLINE static real_t Map(real_t a) { + return tanhf( a ); + } +}; + +struct tanh_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return 1.0f - a * a; + } +}; + + +struct square { + MSHADOW_XINLINE static real_t Map(real_t a) { + return a * a; + } +}; + +/*! \brief used to generate a Bernoulli mask: 1 if a < b, otherwise 0 */ +struct threshold { + MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { + return a < b ? 1.0f : 0.0f; + } +}; + +/*!
\brief used to compute a raised to the power b */ +struct power { + MSHADOW_XINLINE static real_t Map(real_t a, real_t b) { + return powf( a, b ); + } +}; + +/*! \brief used to compute the square root of a */ +struct square_root { + MSHADOW_XINLINE static real_t Map(real_t a) { + return sqrt(a); + } +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_MSHADOW_OPERATOR_OP_H_ + diff --git a/src/operator/operator-inl.h b/src/operator/operator-inl.h new file mode 100644 index 000000000000..7bdd0a1b96d1 --- /dev/null +++ b/src/operator/operator-inl.h @@ -0,0 +1,35 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file operator-inl.h + * \brief device invariant code to create operators + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_INL_H_ +#define MXNET_OPERATOR_INL_H_ +#include +#include +#include "./mshadow_op.h" +#include "./activation_op-inl.h" +#include "./fully_connect_op-inl.h" + +namespace mxnet { +namespace op { +/*! + * \brief device invariant function to create operators + * \param type the type of operator + * \tparam xpu the device type we are at + */ +template +inline Operator *CreateOperator_(OpType type) { + switch (type) { + case kReLU: + return new ActivationOp(); + case kFullc: + return new FullyConnectOp(); + default: LOG(FATAL) << "unknown OpType"; + } + return NULL; +} +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_INL_H_ diff --git a/src/operator/operator.cc b/src/operator/operator.cc new file mode 100644 index 000000000000..e56d6049eca9 --- /dev/null +++ b/src/operator/operator.cc @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { +// declare the operator +template +Operator *CreateOperator(OpType type); + + +OpType GetOpTpe(const char *type) { + if (!strcmp(type, "relu")) return kReLU; + if (!strcmp(type, "fullc")) return kFullc; + LOG(FATAL) << "unknown operator type " << type; + return kReLU; +} +} + +// implementing the operator factory function +Operator
*Operator::Create(const char *type, + Context ctx) { + op::OpType otype = op::GetOpTpe(type); + if (ctx.dev_mask == cpu::kDevMask) { + return op::CreateOperator(otype); + } + if (ctx.dev_mask == gpu::kDevMask) { +#if MXNET_USE_CUDA + return op::CreateOperator(otype); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } + return NULL; +} + +} // namespace mxnet diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h new file mode 100644 index 000000000000..8fb1066333b3 --- /dev/null +++ b/src/operator/operator_common.h @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file operator_common.h + * \brief common internal header of most operators + * this header includes utility functions operator can use + * common type definitions + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_ +#define MXNET_OPERATOR_OPERATOR_COMMON_H_ + +#include +#include + +namespace mxnet { +namespace op { +/*! + * \brief assign the expression to out according to request + * \param out the data to be assigned + * \param req the assignment request + * \param exp the expression + * \tparam OType output type + * \tparam Exp expression type + */ +template +inline void Assign(OType &out, + Operator::GradReqType req, + const Exp &exp) { + switch (req) { + case Operator::kNullOp: break; + case Operator::kWriteTo: + case Operator::kWriteInplace: out = exp; break; + case Operator::kAddTo: out += exp; break; + default: LOG(FATAL) << "not reached"; + } +} +/*! + * \brief assign shape to out if out is unknown (out.ndim() == 0) + * otherwise check consistency + * \param out the output shape to be stored + * \param shape the inferred shape + */ +template +inline void ShapeAssignCheck(TShape &out, const TS &shape) { + if (out.ndim() == 0) { + out = shape; + } else { + CHECK(out == shape) << "InferShape:: shape inconsistent"; + } +} + +/*! \brief type of operators */ +enum OpType { + kReLU = 0, + kFullc = 1 +}; + +/*!
+ * \brief device invariant function to create operators + * \param type the type of operator + * \tparam xpu the device type we are at + */ +template +Operator *CreateOperator(OpType type); +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/operator_cpu.cc b/src/operator/operator_cpu.cc new file mode 100644 index 000000000000..3d5e7c5f3248 --- /dev/null +++ b/src/operator/operator_cpu.cc @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file operator_cpu.cc + * \brief CPU specialization of operator codes + * \author Bing Xu +*/ +#include "./operator-inl.h" + +namespace mxnet { +namespace op { + +template<> +Operator *CreateOperator(OpType type) { + return CreateOperator_(type); +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_gpu.cu b/src/operator/operator_gpu.cu new file mode 100644 index 000000000000..8fb3b2751f13 --- /dev/null +++ b/src/operator/operator_gpu.cu @@ -0,0 +1,21 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file operator_gpu.cu + * \brief GPU specialization of operator code + * \author Bing Xu +*/ +#include +#include +#include "operator-inl.h" + +namespace mxnet { +namespace op { + +template<> +Operator *CreateOperator(OpType type) { + return CreateOperator_(type); +} + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/param.h b/src/operator/param.h new file mode 100644 index 000000000000..0d8016983c5a --- /dev/null +++ b/src/operator/param.h @@ -0,0 +1,73 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file param.h + * \brief operator params + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_PARAM_H_ +#define MXNET_OPERATOR_PARAM_H_ + +namespace mxnet { +namespace op { +/*! \brief possible parameter for each operator */ +struct Param { + /*! \brief number of hidden nodes */ + int num_hidden; + /*! \brief number of output channel */ + int num_channel; + /*! \brief number of parallel group */ + int num_group; + /*!
\brief kernel height */ + int kernel_y; + /*! \brief kernel width */ + int kernel_x; + /*! \brief stride in y dimension */ + int stride_y; + /*! \brief stride in x dimension */ + int stride_x; + /*! \brief padding in y dimension */ + int pad_y; + /*! \brief padding in x dimension */ + int pad_x; + /*! \brief non-zero to leave out the bias term, 0 to include it */ + int no_bias; + /*! \brief maximum temp_col_size allowed in each layer, SetParam stores the given value shifted left by 18 bits */ + int temp_col_max; + /*! \brief number of input channels */ + int num_input_channel; + /*! \brief number of input hidden nodes, used by fullc */ + int num_input_node; + /*! \brief reserved fields, for future compatibility */ + int reserved[64]; + inline void SetParam(const char *name, const char* val) { + if (!strcmp(name, "nhidden")) num_hidden = atoi(val); + if (!strcmp(name, "num_input_node")) num_input_node = atoi(val); + if (!strcmp(name, "num_input_channel")) num_input_channel = atoi(val); + if (!strcmp(name, "nchannel")) num_channel = atoi(val); + if (!strcmp(name, "ngroup")) num_group = atoi(val); + if (!strcmp(name, "kernel_size")) { + kernel_y = kernel_x = atoi(val); + } + if (!strcmp(name, "kernel_height")) kernel_y = atoi(val); + if (!strcmp(name, "kernel_width")) kernel_x = atoi(val); + if (!strcmp(name, "stride")) { + stride_y = stride_x = atoi(val); + } + if (!strcmp(name, "stride_y")) stride_y = atoi(val); + if (!strcmp(name, "stride_x")) stride_x = atoi(val); + + if (!strcmp(name, "pad")) { + pad_y = pad_x = atoi(val); + } + if (!strcmp(name, "pad_y")) pad_y = atoi(val); + if (!strcmp(name, "pad_x")) pad_x = atoi(val); + if (!strcmp(name, "no_bias")) no_bias = atoi(val); + if (!strcmp(name, "temp_col_max")) temp_col_max = atoi(val) << 18; + } +}; // struct Param +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_PARAM_H_ + +