diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 388ad8c23e90..fe260e082148 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -41,43 +41,6 @@ typedef mshadow::index_t index_t; /*! \brief data type that will be used to store ndarray */ typedef mshadow::default_real_t real_t; -/*! \brief context information about the execution enviroment */ -struct Context { - /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ - int dev_mask; - /*! \brief device id we are going to run it on */ - int dev_id; - /*! \brief constructor */ - Context() : dev_mask(cpu::kDevMask), dev_id(0) {} - /*! - * \brief constructor of context - * \param dev_mask the device mask - * \param dev_id the device id - */ - Context(int dev_mask, int dev_id) - : dev_mask(dev_mask), dev_id(dev_id) {} - /*! - * \brief check if current context equals another one - * \param b another context to compare - * \return whether dev mask and id are same - */ - inline bool operator==(const Context &b) const { - return dev_mask == b.dev_mask && dev_id == b.dev_id; - } -}; - - -/*! - * \brief execution context provides the information needed - * in runtime to actually execute the operation - */ -struct RunContext { - /*! - * \brief the stream of the device, can be NULL or Stream* in GPU mode - */ - void *stream; -}; - /*! \brief dynamic shape type */ typedef mshadow::TShape TShape; /*! \brief storage container type */ diff --git a/include/mxnet/context.h b/include/mxnet/context.h new file mode 100644 index 000000000000..262ba2e787d4 --- /dev/null +++ b/include/mxnet/context.h @@ -0,0 +1,80 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file context.h + * \brief Context information and resources in mxnet. + */ +#ifndef MXNET_CONTEXT_H_ +#define MXNET_CONTEXT_H_ + +namespace mxnet { + +/*! \brief Context information about the execution environment */ +struct Context { + /*! 
\brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ + int dev_mask; + /*! \brief device id we are going to run it on */ + int dev_id; + /*! \brief constructor */ + Context() : dev_mask(cpu::kDevMask), dev_id(0) {} + /*! + * \brief constructor of context + * \param dev_mask the device mask + * \param dev_id the device id + */ + Context(int dev_mask, int dev_id) + : dev_mask(dev_mask), dev_id(dev_id) {} + /*! + * \brief check if current context equals another one + * \param b another context to compare + * \return whether dev mask and id are same + */ + inline bool operator==(const Context &b) const { + return dev_mask == b.dev_mask && dev_id == b.dev_id; + } +}; + +/*! + * \brief execution time context. + * The information needed in runtime for actual execution. + */ +struct RunContext { + /*! + * \brief the stream of the device, can be NULL or Stream* in GPU mode + */ + void *stream; +}; + +/*! + * \brief Additional resources + */ +struct Resource { + /*! \brief Resource type, indicating what the pointer type is */ + enum Type { + /*! \brief mshadow::Random object */ + kRandom, + /*! \brief Temporary space */ + kTempSpace + }; + /*! \brief pointer to the resource */ + void *ptr; +}; + +/*! + * \brief The resources that can be requested by Operator + */ +struct ResourceRequest { + /*! \brief type of resources */ + Resource::Type type; + /*! \brief size requirement if it is a temp space request */ + size_t space_size; + /*! \brief default constructor */ + ResourceRequest() {} + /*! 
+ * \brief default constructor, allow implicit conversion + * \param type type of resources + */ + ResourceRequest(Resource::Type type) : type(type) {} // NOLINT(*) +}; + +} // namespace mxnet +#endif // MXNET_CONTEXT_H_ diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index 9e65d6108f60..18b804b5a2d8 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -15,6 +15,7 @@ #include #include #include "./base.h" +#include "./context.h" namespace mxnet { /*! diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 99829bda92da..c2b6ac3bc882 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -10,7 +10,9 @@ #include #include #include "./base.h" +#include "./context.h" #include "./storage.h" +#include "./context.h" #include "./dag_engine.h" // check c++11 #if DMLC_USE_CXX11 == 0 diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 60284c1a5fa3..0299ef2bf167 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -12,14 +12,9 @@ #include #include #include "./base.h" +#include "./context.h" namespace mxnet { -/*! \brief option to pass into the forward function */ -struct Option { - /*! \brief whether it is training phase*/ - int is_train; -}; - /*! \brief operation request type to Forward and Backward */ enum OpReqType { /*! \brief no operation, do not write anything */ @@ -36,6 +31,28 @@ enum OpReqType { kAddTo }; +/*! + * \brief All the possible information needed by Operator.Forward and Backward + * This is the superset of RunContext. + * We use this data structure to bookkeep everything needed by Forward and Backward. + * \sa Resource + */ +struct OpContext { + /*! \brief whether it is training phase */ + int is_train; + /*! \brief Stream we are running on */ + void *stream; + /*! \brief Resources requested by the operator */ + std::vector requested; + /*! 
+ * \brief set the RunContext related parts + * \param ctx the context + */ + inline void SetRunContext(const RunContext &ctx) { + stream = ctx.stream; + } +}; + /*! * \brief Operator interface. * Operator defins basic operation unit of optimized computation graph in mxnet. @@ -54,30 +71,28 @@ class Operator { virtual ~Operator() {} /*! * \brief perform a forward operation of Operator, save the output to TBlob. - * \param opt option on Forward such as whether this is training phase. - * \param ctx runtime context + * \param ctx runtime context available to this call * \param in_data array of input data, it is const * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace. * \param out_data array of output data, pointer is used to indicate that this is holder * the space of TBlob in out_data must be pre-allocated with InferShape - * \sa OpReqType + * \sa OpReqType, OpContext */ - virtual void Forward(Option opt, - RunContext ctx, + virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data) = 0; /*! - * \brief Perform a backward Operation, write gradient to the in_grad. - * \param ctx runtime context + * \brief Perform a Backward Operation, write gradient to the in_grad. + * \param ctx runtime context available to this call * \param out_grad the gradient value we get from output of the Operator * \param in_data the array of input data. * \param out_data the array of output data. * \param req request types of the saving operation, can be all types. * \param in_grad the array of gradient we need to write to. - * \sa OpReqType + * \sa OpReqType, OpContext */ - virtual void Backward(RunContext ctx, + virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, @@ -114,10 +129,25 @@ class OperatorProperty { virtual std::vector ListReturns() const { return {"output"}; } - /*! 
\return number of outputs of the Operator */ + /*! \return number of real return values of the Operator */ virtual int NumReturns() const { return 1; } + /*! + * \brief get number of visible return values during Symbol creation. + * If NumVisibleReturns() = k, and NumReturns() = n. + * The first k returns will be presented in the resulting symbol. + * + * The rest of the returns can be used for auxiliary states for Backward. + * For example, Dropout will return [data, mask], with NumVisibleReturns() == 1. + * So when user calls sym = Dropout(input), only data is presented in sym. + * But all the returns will be presented in out_data parameter of Backward if requested. + * + * \return number of default return values + */ + virtual int NumVisibleReturns() const { + return NumReturns(); + } /*! * \brief Set the parameters of the Operator. * \param name parameter name @@ -154,6 +184,27 @@ class OperatorProperty { * subclasses override this function. */ virtual std::string TypeString() const = 0; + //-------------------------------------------------------- + // All the below functions are optional to override. + //-------------------------------------------------------- + /*! + * \brief Declare additional resource required in forward pass. + * These additional resources will be presented in OpContext.requested + * in the same order of the returned Resource. + * \return Additional resource request + */ + virtual std::vector ForwardResource() const { + return std::vector(); + } + /*! + * \brief Declare additional resource required in backward pass. + * These additional resources will be presented in OpContext.requested + * in the same order of the returned Resource. + * \return Additional resource request + */ + virtual std::vector BackwardResource() const { + return std::vector(); + } /*! * \brief Declare the input requirement of Backward pass. 
* diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 6afc2885a746..2953cbe0d171 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -6,6 +6,7 @@ #ifndef MXNET_STORAGE_H_ #define MXNET_STORAGE_H_ #include "./base.h" +#include "./context.h" namespace mxnet { /*! \brief memory allocator of storage */ diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index 2c39363fba32..21a8da782972 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace mxnet { /*! \brief namespace to support all possible NArray operator */ diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index 92a95a2ada2c..5c54d37220ee 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -33,8 +33,7 @@ class FullyConnectedOp : public Operator { this->param_ = p; } - virtual void Forward(Option opt, - RunContext ctx, + virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { @@ -57,7 +56,7 @@ class FullyConnectedOp : public Operator { } } - virtual void Backward(RunContext ctx, + virtual void Backward(const OpContext &ctx, const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index d7a56528fb77..86cf54feabfa 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -368,7 +368,7 @@ bool Symbol::InferShape(std::vector *in_shape, Symbol Symbol::Create(OperatorProperty *op) { // use special representation for atomic symbol auto node = std::make_shared(op, ""); - size_t nret = op->NumReturns(); + size_t nret = op->NumVisibleReturns(); Symbol s; for (uint32_t i = 0; i < nret; ++i) { s.heads_.push_back(DataEntry(node, i));