diff --git a/Makefile b/Makefile
index 33dc31fdddbc..fb7ac440a182 100644
--- a/Makefile
+++ b/Makefile
@@ -56,7 +56,7 @@ endif
 #BIN = test/test_threaded_engine test/api_registry_test
 BIN = test/api_registry_test
-OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
+OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o atomic_symbol_cpu.o
 # add threaded engine after it is done
 OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o
 CUOBJ =
@@ -65,7 +65,7 @@ ALIB = api/libmxnet.a
 LIB_DEP = $(DMLC_CORE)/libdmlc.a
 ifeq ($(USE_CUDA), 1)
-	CUOBJ += narray_op_gpu.o static_operator_gpu.o
+	CUOBJ += narray_op_gpu.o static_operator_gpu.o atomic_symbol_gpu.o
 endif
 .PHONY: clean all test lint doc
@@ -87,7 +87,9 @@ static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
 symbol.o: src/symbol/symbol.cc
 registry.o: src/registry.cc
 mxnet_api.o: api/mxnet_api.cc
-operator.o: src/operator/operator.cc
+operator.o: src/operator/static_operator_wrapper.cc
+atomic_symbol_cpu.o: src/symbol/fully_connect_sym.cc
+atomic_symbol_gpu.o: src/symbol/fully_connect_sym.cu
 api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
 api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h
index 8d4019f87d62..842fbe6d0198 100644
--- a/include/mxnet/atomic_symbol.h
+++ b/include/mxnet/atomic_symbol.h
@@ -14,7 +14,7 @@
 #include "./tensor_blob.h"
 
 namespace mxnet {
-class Operator;
+class StaticOperator;
 /*!
  * \brief AtomicSymbol is the base class of all atomic symbols.
  *  This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance
@@ -54,7 +54,7 @@ class AtomicSymbol {
   *  InferShape will modify the vector to fill output TShape
   * \return if the shape inference is successful, return true, else return false.
   */
-  virtual bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) = 0;
+  virtual bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape) const = 0;
  /*!
   * \brief Copy this AtomicSymbol and returns a pointer to the copied object.
   *  this is a virtual function because different subclass of AtomicSymbol would copy differently.
@@ -66,7 +66,8 @@ class AtomicSymbol {
   *  Bind function of AtomicSymbol does not return NArrayOperator, but static operator.
   *  Calling bind from the Symbol wrapper would generate a NArrayOperator.
   */
-  virtual Operator* Bind(Context ctx) const = 0;
+  template<typename xpu>
+  StaticOperator* Bind(Context ctx) const;
  /*!
   * \brief return the type string of the atomic symbol
   *  subclasses override this function.
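Note: changing Bind from a virtual function to a non-virtual member template moves device dispatch out of the vtable, so each concrete symbol now supplies its own per-device factory. A minimal sketch of the pattern this enables, written against the headers in this patch (NoOpSymbol and NoOpOp are hypothetical, not part of the change; fully_connect_sym.cc below is the real instance):

    // Hypothetical symbol showing the per-device Bind_ specialization pattern.
    class NoOpSymbol : public AtomicSymbol {
     public:
      StaticOperator* Bind(Context ctx) const;   // runtime dispatch on ctx.dev_mask
      template<typename xpu>
      StaticOperator* Bind_(Context ctx) const;  // per-device factory
    };
    template<>
    StaticOperator* NoOpSymbol::Bind_<cpu>(Context ctx) const {
      return new NoOpOp<cpu>();  // NoOpOp: a stand-in StaticOperator subclass
    }
    StaticOperator* NoOpSymbol::Bind(Context ctx) const {
      if (ctx.dev_mask == cpu::kDevMask) return Bind_<cpu>(ctx);
      LOG(FATAL) << "device not supported in this sketch";
      return NULL;
    }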
diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h
index 458bf9f6c834..4e7b4448e667 100644
--- a/include/mxnet/narray.h
+++ b/include/mxnet/narray.h
@@ -73,7 +73,7 @@ class NArray {
     DAGEngine::Get()->WaitForVar(ptr_->var);
   }
   /*! \return the associated DAG variable of the narray.*/
-  inline DAGEngine::Variable Var() const {
+  inline DAGEngine::Variable var() const {
     return ptr_->var;
   }
   /*!
diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h
index 4dbcea798cdd..ea1e44990e9d 100644
--- a/include/mxnet/operator.h
+++ b/include/mxnet/operator.h
@@ -17,93 +17,63 @@
 namespace mxnet {
 /*!
- * \brief static operator interface (current interface have not yet todo with scheduler),
- *  operator is a stateful object that can be used to call forward and backprop
- *
- *  This interface relies on pre-allocated memory in TBlob, the caller need to set
- *  the memory region in TBlob correctly before calling Forward and Backward
+ * \brief operator interface
+ *  operator is an object that can be scheduled by the DAG engine directly.
+ *
+ *  This interface relies on NArray. The user should prepare the input NArray and
+ *  output NArray by themselves.
  * \sa Operator
  */
 class Operator {
  public:
-  /*!
-   * \brief construct Operator from StaticOperator and Context
-   * \param op StaticOperator to wrap
-   * \param ctx Context of the Operator
-   */
-  Operator(StaticOperator* op, Context ctx);
-  /*!
-   * \brief get types of input argument of this oeprator
-   * \return a vector corresponding to type of each argument
-   *  this order is same as the order of inputs in Forward, InferShape and Backward
-   */
-  virtual std::vector<ArgType> DescribeArgs() const;
   /*!
    * \brief describe property of op
    * \return a bit map in int
    */
   virtual int DescribeProperty() const;
-  /*!
-   * \brief set param for the operator from string
-   * \param name parameter name
-   * \param val string for configuration
-   */
-  virtual void SetParam(const char *name, const char *val);
-  /*!
-   * \brief inter the shapes of outputs and unknown input arguments
-   * \param in_shape the shape of input arguments of the operator
-   *  this should be of same length as the vector returned by DescribeArgs
-   *  in_shape allows unknown elements, which are checked by shape.ndim() == 0.
-   *  For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
-   *  For known shapes, InferShape will check shape consistency
-   *
-   *  common practice: set the shape of data input, and usually weight's shape can be infered
-   *
-   * \param out_shape the shape of outputs of the operator
-   *  InferShape will modify the vector to fill output TShape
-   */
-  virtual void InferShape(std::vector<TShape> *in_shape,
-                          std::vector<TShape> *out_shape);
-
   /*!
-   * \brief set the context of the Operator
+   * \brief set the global context of the Operator
    * \param ctx the context to be set to
    */
   virtual void SetContext(Context ctx);
   /*!
-   * \brief perform a forward operation of operator, save the output to TBlob
+   * \brief perform a forward operation of operator, save the output to NArray
+   *  This method only pushes an execution request to the DAG engine, and
+   *  returns immediately. Actual execution is conducted by the DAG engine.
    * \param opt option on Forward such as whether this is training phase
    * \param ctx runtime context
    * \param in_data array of input data, it is const
   * \param out_data array of output data,
-   *  the space of TBlob in out_data must be pre-allocated with InferShape
+   *  the space of NArray in out_data must be pre-allocated with InferShape
+   * \sa NArray
    */
   virtual void Forward(Option opt,
                        RunContext ctx,
                        const std::vector<NArray> &in_data,
-                       const std::vector<NArray> &out_data);
+                       const std::vector<NArray> &out_data) = 0;
   /*!
    * \brief perform a backward operation of the operator to get the gradient
+   *  This method only pushes an execution request to the DAG engine, and
+   *  returns immediately. Actual execution is conducted by the DAG engine.
    * \param ctx runtime context
-   * \param grad_next the gradient value we get from output of the operator
+   * \param grad_next the gradient value of the output of the operator, used by the chain rule.
    * \param in_data the array of input data
-   * \param out_grad array of output gradient, there could be three possible TBlob
-   *  in the each element in the array
+   * \param out_grad array of output gradient
   * \param req request types of the gradient saving operation
    *  only inplace will change input data
-   * \sa GradReqType
+   * \sa GradReqType, NArray
    */
   virtual void Backward(RunContext ctx,
                         const std::vector<NArray> &grad_next,
                         const std::vector<NArray> &in_data,
                         const std::vector<NArray> &out_grad,
-                        const std::vector<GradReqType> &req);
+                        const std::vector<GradReqType> &req) = 0;
 
- private:
-  /* \brief the static operator */
-  StaticOperator* op;
+ protected:
+  /**
+   * \brief the global context denotes the device info.
+   */
   Context global_ctx;
-};
+};  // class Operator
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_H_
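Because Forward and Backward now only enqueue work on the DAG engine, a caller that wants to read the outputs must synchronize through the NArrays first. A usage sketch under that model (the blocking call below is assumed; the narray.h hunk above only shows its body delegating to DAGEngine::Get()->WaitForVar):

    // Sketch: Forward returns as soon as the request is pushed to the engine.
    void RunForward(mxnet::Operator* op, mxnet::RunContext rctx,
                    const std::vector<mxnet::NArray>& in,
                    const std::vector<mxnet::NArray>& out) {
      mxnet::Option opt;                // forward options, e.g. training flag
      op->Forward(opt, rctx, in, out);  // asynchronous: only queues the request
      for (size_t i = 0; i < out.size(); ++i) {
        out[i].Wait();                  // assumed blocking helper over WaitForVar
      }
    }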
diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h
index b2f9e49af154..2f989bf5aca8 100644
--- a/include/mxnet/static_operator.h
+++ b/include/mxnet/static_operator.h
@@ -23,15 +23,6 @@ namespace mxnet {
  */
 class StaticOperator {
  public:
-  /*!
-   * \brief get types of input argument of this oeprator
-   * \return a vector corresponding to type of each argument
-   *  this order is same as the order of inputs in Forward, InferShape and Backward
-   */
-  virtual std::vector<ArgType> DescribeArgs() const {
-    // default most of layers only have one data argument
-    return std::vector<ArgType>(1, kDataArg);
-  }
   /*!
    * \brief describe property of op
    * \return a bit map in int
@@ -40,27 +31,6 @@ class StaticOperator {
     // default most of layer only conatin internal state
     return kContainInteralState;
   }
-  /*!
-   * \brief set param for the StaticOperator from string
-   * \param name parameter name
-   * \param val string for configuration
-   */
-  virtual void SetParam(const char *name, const char *val) {}
-  /*!
-   * \brief inter the shapes of outputs and unknown input arguments
-   * \param in_shape the shape of input arguments of the StaticOperator
-   *  this should be of same length as the vector returned by DescribeArgs
-   *  in_shape allows unknown elements, which are checked by shape.ndim() == 0.
-   *  For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
-   *  For known shapes, InferShape will check shape consistency
-   *
-   *  common practice: set the shape of data input, and usually weight's shape can be infered
-   *
-   * \param out_shape the shape of outputs of the StaticOperator
-   *  InferShape will modify the vector to fill output TShape
-   */
-  virtual void InferShape(std::vector<TShape> *in_shape,
-                          std::vector<TShape> *out_shape) = 0;
   /*!
    * \brief perform a forward operation of StaticOperator, save the output to TBlob
    * \param opt option on Forward such as whether this is training phase
@@ -90,7 +60,7 @@ class StaticOperator {
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) = 0;
   /*!
-   * \brief factory unction, create a new StaticOperator
+   * \brief factory function, create a new StaticOperator
    * \param type the type of StaticOperator
    * \param ctx the context device type of StaticOperator
    * \return a pointer of StaticOperator object
diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h
index 6f0b5ef8af93..0b69005f7a16 100644
--- a/include/mxnet/symbol.h
+++ b/include/mxnet/symbol.h
@@ -15,9 +15,9 @@
 #include <vector>
 #include "./base.h"
 #include "./tensor_blob.h"
+#include "./operator.h"
 
 namespace mxnet {
-class NArrayOperator;
 /*!
  * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol
  *  should support expressions and often passed by value. While AtomicSymbol have many subclasses,
@@ -67,11 +67,11 @@ class Symbol {
   */
  virtual ~Symbol() {}
  /*!
-  * \brief bind to device and returns an NArrayOperator.
+  * \brief bind to device and returns an operator.
   * \param ctx context of the operator
-  * \return returns the pointer to a created NArrayOperator. It is on the user to delete.
+  * \return returns the pointer to a created operator. It is on the user to delete.
   */
-  virtual NArrayOperator* Bind(Context ctx) const { return nullptr; }
+  virtual Operator* Bind(Context ctx) const { return nullptr; }
  /*!
   * \brief copy the symbol
   * \return a deep copy of the graph
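With Bind returning the generic Operator, a hypothetical call site looks like the following (symbol construction elided; the ownership rule comes from the doc comment above):

    // Sketch: a bound symbol yields a DAG-scheduled Operator owned by the caller.
    mxnet::Symbol net;                      // assume: composed elsewhere
    mxnet::Context ctx;                     // target device context
    mxnet::Operator* exec = net.Bind(ctx);  // base Symbol returns nullptr
    // ... run Forward/Backward through exec ...
    delete exec;                            // "It is on the user to delete."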
diff --git a/src/operator/operator.cc b/src/operator/static_operator_wrapper.cc
similarity index 60%
rename from src/operator/operator.cc
rename to src/operator/static_operator_wrapper.cc
index ccfc640d3a3d..7304e071461d 100644
--- a/src/operator/operator.cc
+++ b/src/operator/static_operator_wrapper.cc
@@ -4,56 +4,22 @@
  * \brief the implementation of narray operator
  * \author Naiyan Wang
  */
-#include <mxnet/operator.h>
+#include "./static_operator_wrapper.h"
 
 namespace mxnet {
-  Operator::Operator(StaticOperator* op, Context ctx) {
+  StaticOperatorWrapper::StaticOperatorWrapper(StaticOperator* op, Context ctx) {
     this->op = op;
     this->global_ctx = ctx;
   }
-  /*!
-   * \brief get types of input argument of this oeprator
-   * \return a vector corresponding to type of each argument
-   *  this order is same as the order of inputs in Forward, InferShape and Backward
-   */
-  std::vector<ArgType> Operator::DescribeArgs() const {
-    // default most of layers only have one data argument
-    return op->DescribeArgs();
-  }
   /*!
    * \brief describe property of op
    * \return a bit map in int
    */
-  int Operator::DescribeProperty() const {
+  int StaticOperatorWrapper::DescribeProperty() const {
     // default most of layer only conatin internal state
     return op->DescribeProperty();
   }
-  /*!
-   * \brief set param for the operator from string
-   * \param name parameter name
-   * \param val string for configuration
-   */
-  void Operator::SetParam(const char *name, const char *val) {
-    op->SetParam(name, val);
-  }
-  /*!
-   * \brief inter the shapes of outputs and unknown input arguments
-   * \param in_shape the shape of input arguments of the operator
-   *  this should be of same length as the vector returned by DescribeArgs
-   *  in_shape allows unknown elements, which are checked by shape.ndim() == 0.
-   *  For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
-   *  For known shapes, InferShape will check shape consistency
-   *
-   *  common practice: set the shape of data input, and usually weight's shape can be infered
-   *
-   * \param out_shape the shape of outputs of the operator
-   *  InferShape will modify the vector to fill output TShape
-   */
-  void Operator::InferShape(std::vector<TShape> *in_shape,
-                            std::vector<TShape> *out_shape) {
-    op->InferShape(in_shape, out_shape);
-  }
   /*!
    * \brief perform a forward operation of operator, save the output to TBlob
    * \param opt option on Forward such as whether this is training phase
    * \param ctx runtime context
    * \param in_data array of input data, it is const
    * \param out_data array of output data,
    *  the space of TBlob in out_data must be pre-allocated with InferShape
    */
-  void Operator::Forward(Option opt,
+  void StaticOperatorWrapper::Forward(Option opt,
                          RunContext ctx,
                          const std::vector<NArray> &in_data,
                          const std::vector<NArray> &out_data) {
@@ -71,11 +37,11 @@ namespace mxnet {
     std::vector<TBlob> in;
     std::vector<TBlob> out;
     for (size_t i = 0; i < in_data.size(); ++i) {
-      used_var.push_back(in_data[i].Var());
+      used_var.push_back(in_data[i].var());
       in.push_back(in_data[i].data());
     }
     for (size_t i = 0; i < out_data.size(); ++i) {
-      mutate_var.push_back(out_data[i].Var());
+      mutate_var.push_back(out_data[i].var());
       out.push_back(out_data[i].data());
     }
     DAGEngine::Get()->Push([this, opt, ctx, in, out](RunContext ctx) {
@@ -93,7 +59,7 @@ namespace mxnet {
    *  only inplace will change input data
    * \sa GradReqType
    */
-  void Operator::Backward(RunContext ctx,
+  void StaticOperatorWrapper::Backward(RunContext ctx,
                           const std::vector<NArray> &grad_next,
                           const std::vector<NArray> &in_data,
                           const std::vector<NArray> &out_grad,
@@ -104,15 +70,15 @@ namespace mxnet {
     std::vector<TBlob> grad_out;
     std::vector<TBlob> data;
     for (size_t i = 0; i < grad_next.size(); ++i) {
-      used_var.push_back(grad_next[i].Var());
+      used_var.push_back(grad_next[i].var());
       grad_in.push_back(grad_next[i].data());
     }
     for (size_t i = 0; i < in_data.size(); ++i) {
-      used_var.push_back(in_data[i].Var());
+      used_var.push_back(in_data[i].var());
       data.push_back(in_data[i].data());
     }
     for (size_t i = 0; i < out_grad.size(); ++i) {
-      mutate_var.push_back(out_grad[i].Var());
+      mutate_var.push_back(out_grad[i].var());
       grad_out.push_back(out_grad[i].data());
     }
     DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) {
@@ -120,7 +86,7 @@ namespace mxnet {
     }, global_ctx, used_var, mutate_var);
   }
 
-  void Operator::SetContext(Context ctx) {
+  void StaticOperatorWrapper::SetContext(Context ctx) {
     this->global_ctx = ctx;
   }
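Both methods above follow a single pattern: input NArrays become read dependencies, output NArrays become write dependencies, and the wrapped StaticOperator runs inside the pushed closure. Its skeleton, paraphrased from Forward (no new behavior, local names shortened):

    // Collect read vars from inputs and mutate vars from outputs, then push.
    std::vector<DAGEngine::Variable> used_var, mutate_var;
    std::vector<TBlob> in, out;
    for (size_t i = 0; i < in_data.size(); ++i) {
      used_var.push_back(in_data[i].var());     // engine waits on pending writers
      in.push_back(in_data[i].data());
    }
    for (size_t i = 0; i < out_data.size(); ++i) {
      mutate_var.push_back(out_data[i].var());  // engine serializes writes
      out.push_back(out_data[i].data());
    }
    DAGEngine::Get()->Push(
        [this, opt, in, out](RunContext rctx) { op->Forward(opt, rctx, in, out); },
        global_ctx, used_var, mutate_var);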
diff --git a/src/operator/static_operator_wrapper.h b/src/operator/static_operator_wrapper.h
new file mode 100644
index 000000000000..6897ff9751f6
--- /dev/null
+++ b/src/operator/static_operator_wrapper.h
@@ -0,0 +1,81 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file static_operator_wrapper.h
+ * \brief operator interface of mxnet
+ * \author Naiyan Wang
+ */
+#ifndef MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_
+#define MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_
+// this file will be seen by cuda, no c++11 for now
+#include <dmlc/logging.h>
+#include <mxnet/base.h>
+#include <mxnet/tensor_blob.h>
+#include <mxnet/operator.h>
+#include <mxnet/static_operator.h>
+#include <mxnet/narray.h>
+#include <mxnet/dag_engine.h>
+#include <vector>
+
+namespace mxnet {
+/*!
+ * \brief static operator interface (the current interface does not yet interact with the scheduler);
+ *  operator is a stateful object that can be used to call forward and backprop
+ *
+ *  This interface relies on pre-allocated memory in TBlob; the caller needs to set
+ *  the memory region in TBlob correctly before calling Forward and Backward
+ *
+ * \sa Operator
+ */
+class StaticOperatorWrapper: public Operator {
+ public:
+  /*!
+   * \brief construct Operator from StaticOperator and Context
+   * \param op StaticOperator to wrap
+   * \param ctx Context of the Operator
+   */
+  StaticOperatorWrapper(StaticOperator* op, Context ctx);
+  /*!
+   * \brief describe property of op
+   * \return a bit map in int
+   */
+  virtual int DescribeProperty() const;
+  /*!
+   * \brief set the context of the Operator
+   * \param ctx the context to be set to
+   */
+  virtual void SetContext(Context ctx);
+  /*!
+   * \brief perform a forward operation of operator, save the output to TBlob
+   * \param opt option on Forward such as whether this is training phase
+   * \param ctx runtime context
+   * \param in_data array of input data, it is const
+   * \param out_data array of output data,
+   *  the space of TBlob in out_data must be pre-allocated with InferShape
+   */
+  virtual void Forward(Option opt,
+                       RunContext ctx,
+                       const std::vector<NArray> &in_data,
+                       const std::vector<NArray> &out_data);
+  /*!
+   * \brief perform a backward operation of the operator to get the gradient
+   * \param ctx runtime context
+   * \param grad_next the gradient value we get from output of the operator
+   * \param in_data the array of input data
+   * \param out_grad array of output gradient, there could be three possible TBlob
+   *  in each element of the array
+   * \param req request types of the gradient saving operation
+   *  only inplace will change input data
+   * \sa GradReqType
+   */
+  virtual void Backward(RunContext ctx,
+                        const std::vector<NArray> &grad_next,
+                        const std::vector<NArray> &in_data,
+                        const std::vector<NArray> &out_grad,
+                        const std::vector<GradReqType> &req);
+
+ private:
+  /*! \brief the static operator */
+  StaticOperator* op;
+};
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_STATIC_OPERATOR_WRAPPER_H_
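For orientation, a plausible construction path for this wrapper, combining the constructor above with StaticOperator::Create from static_operator.cc below (illustrative, not code from this patch):

    // Lift a device-level StaticOperator into a DAG-scheduled Operator.
    mxnet::Context ctx;  // target device
    mxnet::StaticOperator* sop = mxnet::StaticOperator::Create("fullc", ctx);
    mxnet::Operator* op = new mxnet::StaticOperatorWrapper(sop, ctx);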
diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h
index de0250f101bd..7cbf05653903 100644
--- a/src/static_operator/fully_connect_op-inl.h
+++ b/src/static_operator/fully_connect_op-inl.h
@@ -1,55 +1,43 @@
 /*!
  * Copyright (c) 2015 by Contributors
  * \file fully_connect_op-inl.h
- * \brief fully connect operator
- * \author Bing Xu
+ * \brief fully connect operator and symbol
 */
 #ifndef MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_
 #define MXNET_STATIC_OPERATOR_FULLY_CONNECT_OP_INL_H_
 #include <dmlc/logging.h>
 #include <mxnet/static_operator.h>
+#include <mxnet/atomic_symbol.h>
 #include <vector>
+#include <string>
 #include "./static_operator_common.h"
 #include "./param.h"
 
 namespace mxnet {
 namespace op {
+/**
+ * \brief This is the implementation of the fully connected layer.
+ *
+ * \tparam xpu The device that the op will be executed on.
+ */
 template<typename xpu>
 class FullyConnectOp : public StaticOperator {
  public:
-  virtual std::vector<ArgType> DescribeArgs() const {
-    ArgType ret[] = {kDataArg, kWeightArg, kBiasArg};
-    if (param_.no_bias == 0) {
-      return std::vector<ArgType>(ret, ret + 3);
-    } else {
-      return std::vector<ArgType>(ret, ret + 2);
-    }
-  }
-  virtual void SetParam(const char *name, const char *val) {
-    param_.SetParam(name, val);
+  /*!
+   * \brief default constructor
+   */
+  FullyConnectOp() {
+    // Do nothing.
   }
-  virtual void InferShape(std::vector<TShape> *in_shape,
-                          std::vector<TShape> *out_shape) {
-    using namespace mshadow;
-    if (param_.no_bias == 0) {
-      CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]";
-    } else {
-      CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]";
-    }
-    CHECK_GT(param_.num_hidden, 0);
-    const TShape &dshape = (*in_shape)[0];
-    CHECK_EQ(dshape.ndim(), 4) << \
-        "Input data should be 4D in batch-1-1-hidden";
-    CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known";
-    ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3]));
-    if (param_.no_bias == 0) {
-      ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden));
-    }
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    (*out_shape)[0][3] = param_.num_hidden;
+
+  /*!
+   * \brief constructor with parameters. Used in Bind() in the corresponding symbol.
+   */
+  explicit FullyConnectOp(Param p) {
+    this->param_ = p;
   }
+
   virtual void Forward(Option opt,
                        RunContext ctx,
                        const std::vector<TBlob> &in_data,
@@ -71,6 +59,7 @@ class FullyConnectOp : public StaticOperator {
       out += repmat(bias, data.size(0));
     }
   }
+
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
@@ -104,8 +93,44 @@ class FullyConnectOp : public StaticOperator {
   }
 
  private:
+  /** The param of the fully connected layer. */
   Param param_;
 };  // class FullyConnectOp
+
+/**
+ * @brief The symbol part of the fully connected layer.
+ */
+class FullyConnectSymbol : public AtomicSymbol {
+ public:
+  virtual std::vector<std::string> DescribeArguments() const;
+
+  virtual std::vector<std::string> DescribeReturns() const;
+
+  virtual void SetParam(const char *name, const char *val);
+
+  virtual bool InferShape(std::vector<TShape> *in_shape,
+                          std::vector<TShape> *out_shape) const;
+
+  virtual AtomicSymbol* Copy() const;
+
+  StaticOperator* Bind(Context ctx) const;
+
+  virtual std::string TypeString() const;
+
+  /**
+   * @brief This is the template function of the Bind() implementation.
+   *
+   * @param ctx The device context
+   * @return A device-dependent static operator that can be used for execution.
+   */
+  template<typename xpu>
+  StaticOperator* Bind_(Context ctx) const;
+
+ private:
+  /** The param of the fully connected layer. */
+  Param param_;
+};  // class FullyConnectSymbol
+
 }  // namespace op
 }  // namespace mxnet
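To make the InferShape contract declared above concrete (the implementation is in fully_connect_sym.cc below): with num_hidden set to 10 and only the data shape known, the weight and bias entries are filled in place. A worked sketch, assuming TShape accepts assignment from an mshadow shape as the ShapeAssignCheck calls suggest:

    // Before: data shape known, weight/bias unknown (ndim() == 0).
    std::vector<mxnet::TShape> in_shape(3), out_shape;
    in_shape[0] = mshadow::Shape4(8, 1, 1, 100);  // batch-1-1-hidden layout
    fc_sym.InferShape(&in_shape, &out_shape);
    // After: in_shape[1] == (10, 100), in_shape[2] == (10,),
    //        out_shape[0] == (8, 1, 1, 10).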
diff --git a/src/static_operator/static_operator.cc b/src/static_operator/static_operator.cc
index 67464fb394b6..4a2a121532dd 100644
--- a/src/static_operator/static_operator.cc
+++ b/src/static_operator/static_operator.cc
@@ -12,11 +12,12 @@
 namespace mxnet {
 namespace op {
-// declare the operator
-template<typename xpu>
-StaticOperator *CreateOperator(OpType type);
-
-
+/**
+ * @brief return an OpType based on the string description
+ *
+ * @param type the string description of the operator
+ * @return the OpType indicating the type of the operator
+ */
 OpType GetOpType(const char *type) {
   if (!strcmp(type, "relu")) return kReLU;
   if (!strcmp(type, "fullc")) return kFullc;
@@ -25,7 +26,6 @@ OpType GetOpType(const char *type) {
   }
 }  // namespace op
 
-// implementing the context
 StaticOperator *StaticOperator::Create(const char *type,
                                        Context ctx) {
   op::OpType otype = op::GetOpType(type);
diff --git a/src/symbol/fully_connect_sym.cc b/src/symbol/fully_connect_sym.cc
new file mode 100644
index 000000000000..6ef21437b43e
--- /dev/null
+++ b/src/symbol/fully_connect_sym.cc
@@ -0,0 +1,79 @@
+ /*!
+ * Copyright (c) 2015 by Contributors
+ * \file fully_connect_sym.cc
+ * \brief fully connect operator symbol
+*/
+#include "../static_operator/fully_connect_op-inl.h"
+
+namespace mxnet {
+namespace op {
+  std::vector<std::string> FullyConnectSymbol::DescribeArguments() const {
+    std::string ret[] = {"data", "weight", "bias"};
+    if (param_.no_bias == 0) {
+      return std::vector<std::string>(ret, ret + 3);
+    } else {
+      return std::vector<std::string>(ret, ret + 2);
+    }
+  }
+
+  std::vector<std::string> FullyConnectSymbol::DescribeReturns() const {
+    std::string temp = "output";
+    std::vector<std::string> v;
+    v.push_back(temp);
+    return v;
+  }
+
+  void FullyConnectSymbol::SetParam(const char *name, const char *val) {
+    param_.SetParam(name, val);
+  }
+
+  bool FullyConnectSymbol::InferShape(std::vector<TShape> *in_shape,
+                                      std::vector<TShape> *out_shape) const {
+    using namespace mshadow;
+    if (param_.no_bias == 0) {
+      CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]";
+    } else {
+      CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]";
+    }
+    CHECK_GT(param_.num_hidden, 0);
+    const TShape &dshape = (*in_shape)[0];
+    CHECK_EQ(dshape.ndim(), 4) << \
+        "Input data should be 4D in batch-1-1-hidden";
+    CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known";
+    ShapeAssignCheck((*in_shape)[1], Shape2(param_.num_hidden, dshape[3]));
+    if (param_.no_bias == 0) {
+      ShapeAssignCheck((*in_shape)[2], Shape1(param_.num_hidden));
+    }
+    out_shape->clear();
+    out_shape->push_back(dshape);
+    (*out_shape)[0][3] = param_.num_hidden;
+    return true;
+  }
+
+  AtomicSymbol* FullyConnectSymbol::Copy() const {
+    FullyConnectSymbol* fc_sym = new FullyConnectSymbol();
+    fc_sym->param_ = this->param_;
+    return fc_sym;
+  }
+  std::string FullyConnectSymbol::TypeString() const {
+    return "fully_connected";
+  }
+
+  template<>
+  StaticOperator* FullyConnectSymbol::Bind_<cpu>(Context ctx) const {
+    return new FullyConnectOp<cpu>(param_);
+  }
+
+  StaticOperator* FullyConnectSymbol::Bind(Context ctx) const {
+    if (ctx.dev_mask == cpu::kDevMask) {
+      return Bind_<cpu>(ctx);
+    } else {
+#if MXNET_USE_CUDA
+      return Bind_<gpu>(ctx);
+#else
+      LOG(FATAL) << "GPU is not enabled";
+#endif
+    }
+  }
+}  // namespace op
+}  // namespace mxnet
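End to end, the pieces in this file compose roughly as below (hypothetical driver code; the "num_hidden" key is an assumption about what Param::SetParam accepts):

    // Configure the symbol, then bind; Bind dispatches on ctx.dev_mask.
    mxnet::op::FullyConnectSymbol fc;
    fc.SetParam("num_hidden", "10");            // assumed parameter key
    mxnet::Context ctx;                         // a CPU context takes Bind_<cpu>
    mxnet::StaticOperator* sop = fc.Bind(ctx);  // delete when done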
diff --git a/src/symbol/fully_connect_sym.cu b/src/symbol/fully_connect_sym.cu
new file mode 100644
index 000000000000..6fcabf019d1a
--- /dev/null
+++ b/src/symbol/fully_connect_sym.cu
@@ -0,0 +1,15 @@
+ /*!
+ * Copyright (c) 2015 by Contributors
+ * \file fully_connect_sym.cu
+ * \brief fully connect operator symbol
+*/
+#include "../static_operator/fully_connect_op-inl.h"
+
+namespace mxnet {
+namespace op {
+  template<>
+  StaticOperator* FullyConnectSymbol::Bind_<gpu>(Context ctx) const {
+    return new FullyConnectOp<gpu>(param_);
+  }
+}  // namespace op
+}  // namespace mxnet
\ No newline at end of file